aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2016-03-17 17:46:03 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-03-18 08:47:32 -0700
commit50edc4e878c076c44fecc847af110a19b171eb63 (patch)
tree8ea27610ddefee566c2a3f750c2e034085af5b83 /tensorflow
parent1a39c2c1979706084338352d1264951b3ec9c6bc (diff)
TensorFlow: move eigen some NN code from our third_party/eigen3 copy
to being part of TF, add tests. Change: 117509710
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/core/kernels/BUILD53
-rw-r--r--tensorflow/core/kernels/attention_ops.cc2
-rw-r--r--tensorflow/core/kernels/avgpooling_op.cc2
-rw-r--r--tensorflow/core/kernels/avgpooling_op.h2
-rw-r--r--tensorflow/core/kernels/conv_2d.h3
-rw-r--r--tensorflow/core/kernels/eigen_activations.h125
-rw-r--r--tensorflow/core/kernels/eigen_activations_test.cc101
-rw-r--r--tensorflow/core/kernels/eigen_attention.h244
-rw-r--r--tensorflow/core/kernels/eigen_attention_test.cc107
-rw-r--r--tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h539
-rw-r--r--tensorflow/core/kernels/eigen_backward_spatial_convolutions.h359
-rw-r--r--tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc1959
-rw-r--r--tensorflow/core/kernels/eigen_cuboid_convolution.h195
-rw-r--r--tensorflow/core/kernels/eigen_patch_3d.h257
-rw-r--r--tensorflow/core/kernels/eigen_pooling.h441
-rw-r--r--tensorflow/core/kernels/eigen_pooling_test.cc742
-rw-r--r--tensorflow/core/kernels/eigen_softmax.h90
-rw-r--r--tensorflow/core/kernels/eigen_softmax_test.cc65
-rw-r--r--tensorflow/core/kernels/eigen_spatial_convolutions.h785
-rw-r--r--tensorflow/core/kernels/eigen_spatial_convolutions_test.cc1215
-rw-r--r--tensorflow/core/kernels/maxpooling_op.cc2
-rw-r--r--tensorflow/core/kernels/maxpooling_op.h2
-rw-r--r--tensorflow/core/kernels/maxpooling_op_gpu.h1
-rw-r--r--tensorflow/core/kernels/pooling_ops_common.h1
-rw-r--r--tensorflow/core/kernels/pooling_ops_common_gpu.h1
-rw-r--r--tensorflow/core/kernels/resize_nearest_neighbor_op_gpu.h1
26 files changed, 7283 insertions, 11 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index d4d9f2f22f..1d51656a48 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -54,6 +54,7 @@ cc_library(
name = "conv_2d",
hdrs = ["conv_2d.h"],
deps = [
+ ":eigen_helpers",
"//tensorflow/core:framework",
"//third_party/eigen3",
],
@@ -214,6 +215,24 @@ cc_header_only_library(
deps = [":bounds_check"],
)
+cc_library(
+ name = "eigen_helpers",
+ hdrs = [
+ "eigen_activations.h",
+ "eigen_attention.h",
+ "eigen_backward_cuboid_convolutions.h",
+ "eigen_backward_spatial_convolutions.h",
+ "eigen_cuboid_convolution.h",
+ "eigen_patch_3d.h",
+ "eigen_pooling.h",
+ "eigen_softmax.h",
+ "eigen_spatial_convolutions.h",
+ ],
+ deps = [
+ "//third_party/eigen3",
+ ],
+)
+
# OpKernel libraries ----------------------------------------------------------
tf_kernel_libraries(
@@ -529,12 +548,12 @@ tf_kernel_libraries(
name = "image",
prefixes = [
"adjust_contrast_op",
- "attention_ops",
"colorspace_op",
"decode_jpeg_op",
"decode_png_op",
"draw_bounding_box_op",
"encode_jpeg_op",
+ "attention_ops",
"encode_png_op",
"random_crop_op",
"resize_area_op",
@@ -544,6 +563,7 @@ tf_kernel_libraries(
"sample_distorted_bounding_box_op",
],
deps = [
+ ":eigen_helpers",
"//tensorflow/core:framework",
"//tensorflow/core:image_ops_op_lib",
"//tensorflow/core:lib",
@@ -556,6 +576,27 @@ tf_kernel_libraries(
tf_cc_tests(
linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking
tests = [
+ "eigen_activations_test",
+ "eigen_attention_test",
+ "eigen_backward_spatial_convolutions_test",
+ "eigen_pooling_test",
+ "eigen_softmax_test",
+ "eigen_spatial_convolutions_test",
+ ],
+ deps = [
+ ":eigen_helpers",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+tf_cc_tests(
+ linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking
+ tests = [
"adjust_contrast_op_benchmark_test",
"adjust_contrast_op_test",
"colorspace_op_test",
@@ -820,6 +861,7 @@ tf_kernel_library(
],
deps = [
":conv_2d",
+ ":eigen_helpers",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
@@ -1029,6 +1071,15 @@ filegroup(
srcs = [
"avgpooling_op.h",
"bounds_check.h",
+ "eigen_activations.h",
+ "eigen_attention.h",
+ "eigen_backward_cuboid_convolutions.h",
+ "eigen_backward_spatial_convolutions.h",
+ "eigen_cuboid_convolution.h",
+ "eigen_patch_3d.h",
+ "eigen_pooling.h",
+ "eigen_softmax.h",
+ "eigen_spatial_convolutions.h",
"maxpooling_op.h",
"ops_util.cc",
"ops_util.h",
diff --git a/tensorflow/core/kernels/attention_ops.cc b/tensorflow/core/kernels/attention_ops.cc
index 59e147bf93..36c1b26476 100644
--- a/tensorflow/core/kernels/attention_ops.cc
+++ b/tensorflow/core/kernels/attention_ops.cc
@@ -18,12 +18,12 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include <vector>
-#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/eigen_attention.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/kernels/avgpooling_op.cc b/tensorflow/core/kernels/avgpooling_op.cc
index 37c502ad69..a3c03601c8 100644
--- a/tensorflow/core/kernels/avgpooling_op.cc
+++ b/tensorflow/core/kernels/avgpooling_op.cc
@@ -20,13 +20,13 @@ limitations under the License.
#include "tensorflow/core/kernels/avgpooling_op.h"
#include <vector>
-#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/eigen_pooling.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/pooling_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
diff --git a/tensorflow/core/kernels/avgpooling_op.h b/tensorflow/core/kernels/avgpooling_op.h
index 0b577971f3..2804cdbee5 100644
--- a/tensorflow/core/kernels/avgpooling_op.h
+++ b/tensorflow/core/kernels/avgpooling_op.h
@@ -17,8 +17,8 @@ limitations under the License.
#define TENSORFLOW_KERNELS_AVGPOOLING_OP_H_
// Functor definition for AvgPoolingOp, must be compilable by nvcc.
-#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks"
#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_pooling.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h
index 141343ec3b..9d06853053 100644
--- a/tensorflow/core/kernels/conv_2d.h
+++ b/tensorflow/core/kernels/conv_2d.h
@@ -16,9 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_KERNELS_CONV_2D_H_
#define TENSORFLOW_KERNELS_CONV_2D_H_
-#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_backward_spatial_convolutions.h"
+#include "tensorflow/core/kernels/eigen_spatial_convolutions.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
diff --git a/tensorflow/core/kernels/eigen_activations.h b/tensorflow/core/kernels/eigen_activations.h
new file mode 100644
index 0000000000..252e434811
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_activations.h
@@ -0,0 +1,125 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_ACTIVATIONS_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_ACTIVATIONS_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace Eigen {
+
+/** scalar_sigmoid_fast_derivative_op
+ * \ingroup CXX11_NeuralNetworks_Module
+ * \brief Template functor to compute the fast derivative of a sigmoid
+ *
+ * Input should be the backpropagated gradient.
+ *
+ * \sa class CwiseUnaryOp, Cwise::sigmoid_fast_derivative()
+ */
+template <typename T>
+struct scalar_sigmoid_fast_derivative_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_sigmoid_fast_derivative_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& y) const {
+ const T one = T(1);
+ return (one - y) * y;
+ }
+
+ template <typename Packet>
+ inline Packet packetOp(const Packet& y) const {
+ const Packet one = internal::pset1<Packet>(1);
+ return internal::pmul(internal::psub(one, y), y);
+ }
+};
+
+namespace internal {
+template <typename T>
+struct functor_traits<scalar_sigmoid_fast_derivative_op<T> > {
+ enum {
+ Cost = NumTraits<T>::AddCost * 2 + NumTraits<T>::MulCost,
+ PacketAccess = packet_traits<T>::HasAdd && packet_traits<T>::HasMul &&
+ packet_traits<T>::HasNegate
+ };
+};
+} // namespace internal
+
+/** scalar_tanh_fast_derivative_op
+ * \ingroup CXX11_NeuralNetworks_Module
+ * \brief Template functor to compute the fast derivative of a tanh
+ *
+ * Input should be the backpropagated gradient.
+ *
+ * \sa class CwiseUnaryOp, Cwise::tanh_fast_derivative()
+ */
+template <typename T>
+struct scalar_tanh_fast_derivative_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_tanh_fast_derivative_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& y) const {
+ const T one = T(1);
+ return one - (y * y);
+ }
+
+ template <typename Packet>
+ inline Packet packetOp(const Packet& y) const {
+ const Packet one = internal::pset1<Packet>(1);
+ return internal::psub(one, internal::pmul(y, y));
+ }
+};
+
+namespace internal {
+template <typename T>
+struct functor_traits<scalar_tanh_fast_derivative_op<T> > {
+ enum {
+ Cost = NumTraits<T>::AddCost * 2 + NumTraits<T>::MulCost * 1,
+ PacketAccess = packet_traits<T>::HasAdd && packet_traits<T>::HasMul &&
+ packet_traits<T>::HasNegate
+ };
+};
+} // namespace internal
+
+/**
+ * \ingroup CXX11_NeuralNetworks_Module
+ * \brief Template functor to clip the the magnitude of the first scalar.
+ *
+ * \sa class CwiseBinaryOp, MatrixBase::Clip
+ */
+template <typename Scalar>
+struct scalar_clip_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_clip_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
+ operator()(const Scalar& a, const Scalar& b) const {
+ return numext::mini(numext::maxi(a, -b), b);
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
+ packetOp(const Packet& a, const Packet& b) const {
+ return internal::pmin(internal::pmax(a, internal::pnegate(b)), b);
+ }
+};
+
+namespace internal {
+template <typename Scalar>
+struct functor_traits<scalar_clip_op<Scalar> > {
+ enum {
+ Cost = NumTraits<Scalar>::AddCost * 3,
+ PacketAccess = packet_traits<Scalar>::HasMax &&
+ packet_traits<Scalar>::HasMin &&
+ packet_traits<Scalar>::HasNegate
+ };
+};
+} // namespace internal
+
+} // end namespace Eigen
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_ACTIVATIONS_H_
diff --git a/tensorflow/core/kernels/eigen_activations_test.cc b/tensorflow/core/kernels/eigen_activations_test.cc
new file mode 100644
index 0000000000..390f6e8840
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_activations_test.cc
@@ -0,0 +1,101 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/eigen_activations.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace Eigen {
+
+namespace {
+void EigenApprox(float a, float b) {
+ ASSERT_TRUE(std::abs(a - b) <= std::min(std::abs(a), std::abs(b)) * 1e-3);
+}
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest, SigmoidFastDerivative) {
+ const ptrdiff_t depth = 3;
+ const ptrdiff_t batch = 10;
+ const ptrdiff_t rows = 32;
+ const ptrdiff_t cols = 48;
+
+ Tensor<float, 4> input(depth, rows, cols, batch);
+ input.setRandom();
+
+ Tensor<float, 4> result(depth, rows, cols, batch);
+ result = input.unaryExpr(scalar_sigmoid_fast_derivative_op<float>());
+
+ for (int b = 0; b < batch; ++b) {
+ for (int c = 0; c < cols; ++c) {
+ for (int r = 0; r < rows; ++r) {
+ for (int d = 0; d < depth; ++d) {
+ float val = input(d, r, c, b);
+ EigenApprox(result(d, r, c, b), (1 - val) * val);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest, TanhFastDerivative) {
+ const ptrdiff_t depth = 3;
+ const ptrdiff_t batch = 10;
+ const ptrdiff_t rows = 32;
+ const ptrdiff_t cols = 48;
+
+ Tensor<float, 4> input(depth, rows, cols, batch);
+ input.setRandom();
+
+ Tensor<float, 4> result(depth, rows, cols, batch);
+ result = input.unaryExpr(scalar_tanh_fast_derivative_op<float>());
+
+ for (int b = 0; b < batch; ++b) {
+ for (int c = 0; c < cols; ++c) {
+ for (int r = 0; r < rows; ++r) {
+ for (int d = 0; d < depth; ++d) {
+ float val = input(d, r, c, b);
+ EigenApprox(result(d, r, c, b), 1 - (val * val));
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest, Clip) {
+ const ptrdiff_t depth = 3;
+ const ptrdiff_t batch = 10;
+ const ptrdiff_t rows = 32;
+ const ptrdiff_t cols = 48;
+
+ Tensor<float, 4> input(depth, rows, cols, batch);
+ input.setRandom();
+
+ Tensor<float, 4> result(depth, rows, cols, batch);
+ result = input.binaryExpr(input.constant(0.01), scalar_clip_op<float>());
+
+ for (int b = 0; b < batch; ++b) {
+ for (int c = 0; c < cols; ++c) {
+ for (int r = 0; r < rows; ++r) {
+ for (int d = 0; d < depth; ++d) {
+ float val = input(d, r, c, b);
+ EigenApprox(result(d, r, c, b),
+ (std::min)((std::max)(val, -0.01f), 0.01f));
+ }
+ }
+ }
+ }
+}
+
+} // namespace Eigen
diff --git a/tensorflow/core/kernels/eigen_attention.h b/tensorflow/core/kernels/eigen_attention.h
new file mode 100644
index 0000000000..e7bdda1693
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_attention.h
@@ -0,0 +1,244 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace Eigen {
+
+/** ExtractGlimpses
+ * \ingroup CXX11_NeuralNetworks_Module
+ *
+ * \brief Extract glimpses from an input tensor.
+ *
+ * The input parameter is expected to be a col-major tensor with a rank of 4 (depth, x, y, and batch).
+ * The width and height parameters specify the extension of the returned glimpses.
+ * The offsets parameter specifies the x, y locations of the center of the glimpses relative to the center of the input image. The vector is expected to contain one IndexPair for each image in the batch dimension.
+ * The normalized boolean indicates if incoming coordinates are normalized so that 0.0 and 1.0 correspond to the minimum and maximum of each height and width dimension.
+ * The centered boolean indicates if incoming coordinates are centered relative to the image, in which case -1.0 and 1.0 correspond to minimum and maximum of each dimension while 0.0 corresponds to the center.
+ *
+ * The result can be assigned to a tensor of rank equal to that of the input. The result will be laid out in col-major order (depth, x, y, batch).
+ * The dimensions of the result will be equal to the dimensions of the input except for width and height which will be equal to the requested glimpse size.
+ */
+namespace {
+template <typename Index>
+struct GlimpseExtractionOp {
+ GlimpseExtractionOp(const Index width, const Index height,
+ const std::vector<IndexPair<float> >& offsets,
+ const bool normalized,
+ const bool centered,
+ const bool uniform_noise) :
+ width_(width), height_(height), offsets_(offsets),
+ normalized_(normalized), centered_(centered), uniform_noise_(uniform_noise) { }
+
+ template <typename Input>
+ DSizes<Index, 4> dimensions(const Input& input) const {
+ typedef typename internal::traits<Input>::Index IndexType;
+ typedef TensorRef<Tensor<typename internal::traits<Input>::Scalar, 4,
+ internal::traits<Input>::Layout, IndexType> > Ref;
+ Ref in(input);
+
+ DSizes<Index, 4> dims = in.dimensions();
+
+ dims[0] = in.dimension(0);
+ dims[1] = width_;
+ dims[2] = height_;
+ dims[3] = in.dimension(3);
+ return dims;
+ }
+
+ template <typename Input, typename Output, typename Device>
+ EIGEN_DEVICE_FUNC
+ void eval(const Input& input, Output& output, const Device& device) const
+ {
+ typedef typename internal::traits<Input>::Index IndexType;
+ typedef TensorRef<Tensor<typename internal::traits<Input>::Scalar, 4,
+ internal::traits<Input>::Layout, IndexType> > Ref;
+ Ref in(input);
+ const Index num_channels = in.dimension(0);
+ const Index input_width = in.dimension(1);
+ const Index input_height = in.dimension(2);
+ const Index batch_size = in.dimension(3);
+ eigen_assert(input_width > 0);
+ eigen_assert(input_height > 0);
+ internal::NormalRandomGenerator<float> gen;
+ internal::UniformRandomGenerator<float> unigen;
+
+ for (Index i = 0; i < batch_size; ++i) {
+ float x = offsets_[i].first, y = offsets_[i].second;
+
+ // Un-normalize coordinates back to pixel space if normalized.
+ if (normalized_) {
+ x *= input_width;
+ y *= input_height;
+ }
+ // Un-center if coordinates are centered on the image center.
+ if (centered_) {
+ x /= 2.0f;
+ y /= 2.0f;
+ x += input_width / 2.0f;
+ y += input_height / 2.0f;
+ }
+ // Remove half of the glimpse window.
+ x -= width_ / 2.0f;
+ y -= height_ / 2.0f;
+
+ const Index offset_x = (Index) x;
+ const Index offset_y = (Index) y;
+ Index glimpse_width = width_;
+ Index glimpse_height = height_;
+ bool partial_overlap = false;
+ DSizes<Index, 3> slice_offset(0, offset_x, offset_y);
+ DSizes<Index, 3> slice_extent(num_channels, width_, height_);
+ DSizes<Index, 3> base_offset(0, 0, 0);
+
+ if (offset_x < 0) {
+ slice_offset[1] = 0;
+ glimpse_width = (std::max<Index>)(0, width_ + offset_x);
+ slice_extent[1] = glimpse_width;
+ base_offset[1] = width_ - glimpse_width;
+ partial_overlap = true;
+ } else if (offset_x + width_ >= input_width) {
+ glimpse_width = (std::max<Index>)(0, input_width - offset_x);
+ slice_extent[1] = glimpse_width;
+ partial_overlap = true;
+ }
+ if (offset_y < 0) {
+ slice_offset[2] = 0;
+ glimpse_height = (std::max<Index>)(0, height_ + offset_y);
+ slice_extent[2] = glimpse_height;
+ base_offset[2] = height_ - glimpse_height;
+ partial_overlap = true;
+ } else if (offset_y + height_ >= input_height) {
+ glimpse_height = (std::max<Index>)(0, input_height - offset_y);
+ slice_extent[2] = glimpse_height;
+ partial_overlap = true;
+ }
+ slice_extent[1] = std::min<Index>(input_width, slice_extent[1]);
+ slice_extent[2] = std::min<Index>(input_height, slice_extent[2]);
+
+
+ if (partial_overlap) {
+
+ if (uniform_noise_) {
+ // Initialize the glimpse with uniform noise.
+ typedef typename internal::remove_const<
+ typename internal::traits<Input>::Scalar>::type Scalar;
+ TensorFixedSize<Scalar, Sizes<> > mini;
+ mini.device(device) = input.template chip<3>(i).minimum();
+ TensorFixedSize<float, Sizes<> > range;
+ range.device(device) = (input.template chip<3>(i).maximum() - mini)
+ .template cast<float>();
+
+ DSizes<Index, 3> glimpse_size(num_channels, width_, height_);
+ TensorMap<Tensor<float, 3> > tmp(NULL, glimpse_size);
+ output.template chip<3>(i).device(device) =
+ mini.reshape(Sizes<1, 1, 1>()).broadcast(glimpse_size) +
+ (tmp.random(unigen) *
+ range.reshape(Sizes<1, 1, 1>()).broadcast(glimpse_size))
+ .template cast<Scalar>();
+ } else {
+ // Initialize the glimpse with white noise: compute the mean and sigma
+ // of each channel, and use them to shape the gaussian.
+ DSizes<Index, 2> glimpse_size(width_, height_);
+ DSizes<Index, 2> input_size(input_width, input_height);
+ typedef typename internal::remove_const<
+ typename internal::traits<Input>::Scalar>::type Scalar;
+
+ for (int j = 0; j < num_channels; ++j) {
+ TensorFixedSize<Scalar, Sizes<> > mean;
+ mean.device(device) = input.template chip<3>(i)
+ .template chip<0>(j)
+ .template cast<float>()
+ .mean();
+ TensorFixedSize<float, Sizes<> > sigma;
+ sigma.device(device) =
+ (input.template chip<3>(i)
+ .template chip<0>(j)
+ .template cast<float>() -
+ mean.reshape(Sizes<1, 1>()).broadcast(input_size))
+ .square()
+ .mean()
+ .sqrt();
+ TensorFixedSize<Scalar, Sizes<> > mini;
+ mini.device(device) =
+ input.template chip<3>(i).template chip<0>(j).minimum();
+ TensorFixedSize<float, Sizes<> > maxi;
+ maxi.device(device) =
+ input.template chip<3>(i).template chip<0>(j).maximum();
+
+ TensorMap<Tensor<float, 2> > tmp(NULL, glimpse_size);
+ output.template chip<3>(i).template chip<0>(j).device(device) =
+ (mean.reshape(Sizes<1, 1>()).broadcast(glimpse_size) +
+ (tmp.random(gen) *
+ sigma.reshape(Sizes<1, 1>()).broadcast(glimpse_size))
+ .template cast<Scalar>())
+ .cwiseMin(
+ maxi.reshape(Sizes<1, 1>()).broadcast(glimpse_size))
+ .cwiseMax(
+ mini.reshape(Sizes<1, 1>()).broadcast(glimpse_size));
+ }
+ }
+
+ // Copy the part of the glimpse that cover the input image if any.
+ if (glimpse_width == 0 || glimpse_height == 0) {
+ continue;
+ }
+ output.template chip<3>(i)
+ .slice(base_offset, slice_extent)
+ .device(device) =
+ input.template chip<3>(i).slice(slice_offset, slice_extent);
+ } else {
+ output.template chip<3>(i).device(device) =
+ input.template chip<3>(i).slice(slice_offset, slice_extent);
+ }
+ }
+ }
+
+ private:
+ const Index width_;
+ const Index height_;
+ const std::vector<IndexPair<float> > offsets_;
+ const bool normalized_;
+ const bool centered_;
+ const bool uniform_noise_;
+};
+}
+
+
+template <typename Input>
+EIGEN_ALWAYS_INLINE
+static const TensorCustomUnaryOp<const GlimpseExtractionOp<typename internal::traits<Input>::Index>, const Input>
+ExtractGlimpses(const Input& input,
+ const typename internal::traits<Input>::Index width,
+ const typename internal::traits<Input>::Index height,
+ const std::vector<IndexPair<float> >& offsets,
+ const bool normalized = true, const bool centered = true,
+ const bool uniform_noise = true)
+{
+ EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == ColMajor, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 4, YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ typedef typename internal::traits<Input>::Index Index;
+ const GlimpseExtractionOp<Index> op(width, height, offsets, normalized,
+ centered, uniform_noise);
+ return input.customOp(op);
+}
+
+} // end namespace Eigen
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_
diff --git a/tensorflow/core/kernels/eigen_attention_test.cc b/tensorflow/core/kernels/eigen_attention_test.cc
new file mode 100644
index 0000000000..7d5e0b71b5
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_attention_test.cc
@@ -0,0 +1,107 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/eigen_attention.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace Eigen {
+
+namespace {
+void EigenApprox(float a, float b) {
+ ASSERT_TRUE(std::abs(a - b) <= std::min(std::abs(a), std::abs(b)) * 1e-3);
+}
+}
+
+TEST(EigenAttentionTest, Simple) {
+ const ptrdiff_t depth = 3;
+ const ptrdiff_t batch = 10;
+ const ptrdiff_t rows = 32;
+ const ptrdiff_t cols = 48;
+ const ptrdiff_t glimpse_rows = 8;
+ const ptrdiff_t glimpse_cols = 6;
+
+ Tensor<float, 4> input(depth, rows, cols, batch);
+ input.setRandom();
+
+ std::vector<IndexPair<float>> offsets;
+ offsets.resize(batch);
+ for (int i = 0; i < batch; ++i) {
+ offsets[i].first = (-5 + i) / 10.0f;
+ offsets[i].second = (5 - i) / 10.0f;
+ }
+
+ Tensor<float, 4> result(depth, glimpse_rows, glimpse_cols, batch);
+ result = ExtractGlimpses(input, glimpse_rows, glimpse_cols, offsets);
+
+ for (int b = 0; b < batch; ++b) {
+ for (int c = 0; c < glimpse_cols; ++c) {
+ ptrdiff_t source_c =
+ c + ((1.0f + offsets[b].second) * cols - glimpse_cols) / 2;
+ for (int r = 0; r < glimpse_rows; ++r) {
+ ptrdiff_t source_r =
+ r + ((1.0f + offsets[b].first) * rows - glimpse_rows) / 2;
+ for (int d = 0; d < depth; ++d) {
+ EigenApprox(result(d, r, c, b), input(d, source_r, source_c, b));
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenAttentionTest, OutOfBoundsGlimpse) {
+ const ptrdiff_t depth = 3;
+ const ptrdiff_t batch = 10;
+ const ptrdiff_t rows = 32;
+ const ptrdiff_t cols = 48;
+ const ptrdiff_t glimpse_rows = 8;
+ const ptrdiff_t glimpse_cols = 6;
+
+ Tensor<float, 4> input(depth, rows, cols, batch);
+ input.setRandom();
+
+ std::vector<IndexPair<float>> offsets;
+ offsets.resize(batch);
+ for (int i = 0; i < batch; ++i) {
+ offsets[i].first = (-5 + i) / 2.0f;
+ offsets[i].second = (5 - i) / 2.0f;
+ }
+
+ Tensor<float, 4> result(depth, glimpse_rows, glimpse_cols, batch);
+ result = ExtractGlimpses(input, glimpse_rows, glimpse_cols, offsets);
+
+ for (int b = 0; b < batch; ++b) {
+ for (int c = 0; c < glimpse_cols; ++c) {
+ ptrdiff_t source_c =
+ c + ((1.0f + offsets[b].second) * cols - glimpse_cols) / 2;
+ if (source_c < glimpse_cols / 2 || source_c >= cols - glimpse_cols / 2) {
+ continue;
+ }
+ for (int r = 0; r < glimpse_rows; ++r) {
+ ptrdiff_t source_r =
+ r + ((1.0f + offsets[b].first) * rows - glimpse_rows) / 2;
+ if (source_r < glimpse_rows / 2 ||
+ source_r >= rows - glimpse_rows / 2) {
+ continue;
+ }
+ for (int d = 0; d < depth; ++d) {
+ EigenApprox(result(d, r, c, b), input(d, source_r, source_c, b));
+ }
+ }
+ }
+ }
+}
+
+} // namespace Eigen
diff --git a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
new file mode 100644
index 0000000000..937a0c5acb
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
@@ -0,0 +1,539 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_CUBOID_CONVOLUTIONS_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_CUBOID_CONVOLUTIONS_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/eigen_patch_3d.h"
+
+namespace Eigen {
+
+/** CuboidConvolutionBackwardInput
+ * \ingroup CXX11_NeuralNetworks_Module
+ *
+ * \brief Computes the backprop for the input of a 3D convolution.
+ *
+ * The output_backward parameter is expected to be a tensor with a rank of 4 or more (channels, depth, height, width, and optionally others)
+ * The kernel parameter is expected to be a 5D tensor (filters, channels, kernel_depth, kernel_height, kernel_width)
+ * output_backward and kernel have to be in the same layout.
+ *
+ * The dimensions of the result will be filters, depth, height, width (and others if applicable).
+ *
+ * It is possible to swap the order of the depth, width and height dimensions provided that the same order is used in the input, the kernel, and the output.
+ *
+ * All dimension orders above are given for col-major, and should be reversed for row-major.
+ */
+
+template <typename OutputBackward, typename Kernel>
+EIGEN_ALWAYS_INLINE static const typename internal::conditional<
+ internal::traits<OutputBackward>::Layout == ColMajor,
+ TensorReshapingOp<
+ const DSizes<typename internal::traits<OutputBackward>::Index,
+ internal::traits<OutputBackward>::NumDimensions>,
+ const TensorContractionOp<
+ const array< IndexPair<typename internal::traits<OutputBackward>::Index>, 2>,
+ const TensorReshapingOp<
+ const DSizes< typename internal::traits<OutputBackward>::Index, 3>,
+ const TensorReverseOp<const array<bool, 5>, const Kernel>
+ >,
+ const TensorReshapingOp<
+ const DSizes< typename internal::traits<OutputBackward>::Index, 3>,
+ const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic, const OutputBackward>
+ >
+ >
+ >,
+ TensorReshapingOp<
+ const DSizes<typename internal::traits<OutputBackward>::Index,
+ internal::traits<OutputBackward>::NumDimensions>,
+ const TensorContractionOp<
+ const array< IndexPair<typename internal::traits<OutputBackward>::Index>, 2>,
+ const TensorReshapingOp<
+ const DSizes< typename internal::traits<OutputBackward>::Index, 3>,
+ const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic, const OutputBackward>
+ >,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<OutputBackward>::Index, 3>,
+ const TensorReverseOp<const array<bool, 5>, const Kernel>
+ >
+ >
+ >
+>::type
+CuboidConvolutionBackwardInput(
+ const Kernel& kernel, const OutputBackward& output_backward,
+ typename internal::traits<OutputBackward>::Index inputPlanes,
+ typename internal::traits<OutputBackward>::Index inputRows,
+ typename internal::traits<OutputBackward>::Index inputCols,
+ const DenseIndex stridePlanes = 1, const DenseIndex strideRows = 1,
+ const DenseIndex strideCols = 1) {
+ typedef typename internal::traits<OutputBackward>::Index TensorIndex;
+ const TensorRef<const Tensor<typename internal::traits<Kernel>::Scalar, internal::traits<Kernel>::NumDimensions, internal::traits<Kernel>::Layout, TensorIndex> > kern(kernel);
+ const TensorRef<const Tensor<typename internal::traits<OutputBackward>::Scalar, internal::traits<OutputBackward>::NumDimensions, internal::traits<OutputBackward>::Layout, TensorIndex> > out(output_backward);
+
+ EIGEN_STATIC_ASSERT(internal::traits<Kernel>::Layout == internal::traits<OutputBackward>::Layout, YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ static const bool isColMajor = (internal::traits<OutputBackward>::Layout == ColMajor);
+
+ static const int NumDims = internal::traits<OutputBackward>::NumDimensions;
+
+ // Number of filters to apply. This is the same as the output depth of the result
+ const TensorIndex kernelFilters = isColMajor ? kern.dimensions()[0] : kern.dimensions()[4];
+ // Number of channels. This is the same as the input depth.
+ const TensorIndex kernelChannels = isColMajor ? kern.dimensions()[1] : kern.dimensions()[3];
+ const TensorIndex kernelPlanes = isColMajor ? kern.dimensions()[2] : kern.dimensions()[2];
+ const TensorIndex kernelRows = isColMajor ? kern.dimensions()[3] : kern.dimensions()[1];
+ const TensorIndex kernelCols = isColMajor ? kern.dimensions()[4] : kern.dimensions()[0];
+
+ const TensorIndex outputPlanes = isColMajor ? out.dimensions()[1] : out.dimensions()[NumDims - 2];
+ const TensorIndex outputRows = isColMajor ? out.dimensions()[2] : out.dimensions()[NumDims - 3];
+ const TensorIndex outputCols = isColMajor ? out.dimensions()[3] : out.dimensions()[NumDims - 4];
+
+ TensorIndex forward_pad_z, forward_pad_y, forward_pad_x;
+ const TensorIndex size_z = ceil(inputPlanes / static_cast<float>(stridePlanes));
+ const TensorIndex size_y = ceil(inputRows / static_cast<float>(strideRows));
+ const TensorIndex size_x = ceil(inputCols / static_cast<float>(strideCols));
+
+ // Infer padding type.
+ if (size_z == outputPlanes && size_y == outputRows && size_x == outputCols) {
+ // SAME padding.
+ const TensorIndex dz = size_z * stridePlanes + kernelPlanes - 1 - inputPlanes;
+ const TensorIndex dy = size_y * strideRows + kernelRows - 1 - inputRows;
+ const TensorIndex dx = size_x * strideCols + kernelCols - 1 - inputCols;
+
+ forward_pad_z = dz - dz / 2;
+ forward_pad_y = dy - dy / 2;
+ forward_pad_x = dx - dx / 2;
+ } else {
+ // VALID padding.
+ forward_pad_z = 0;
+ forward_pad_y = 0;
+ forward_pad_x = 0;
+ }
+ const TensorIndex padding_ztop = kernelPlanes - 1 - forward_pad_z;
+ const TensorIndex padding_top = kernelRows - 1 - forward_pad_y;
+ const TensorIndex padding_left = kernelCols - 1 - forward_pad_x;
+
+ const TensorIndex padding_zbottom = inputPlanes + kernelPlanes - 1 - (outputPlanes - 1) * stridePlanes - 1 - padding_ztop;
+ const TensorIndex padding_bottom = inputRows + kernelRows - 1 - (outputRows - 1) * strideRows - 1 - padding_top;
+ const TensorIndex padding_right = inputCols + kernelCols - 1 - (outputCols - 1) * strideCols - 1 - padding_left;
+
+ eigen_assert(padding_ztop >= 0);
+ eigen_assert(padding_zbottom >= 0);
+ eigen_assert(padding_top >= 0);
+ eigen_assert(padding_left >= 0);
+ eigen_assert(padding_bottom >= 0);
+ eigen_assert(padding_right >= 0);
+
+ // The kernel has dimensions filters X channels X patch_planes X patch_rows X patch_cols.
+ // We need to reverse the kernel along the spatial dimensions.
+ array<bool, 5> kernel_reverse;
+ if (isColMajor) {
+ kernel_reverse[0] = false;
+ kernel_reverse[1] = false;
+ kernel_reverse[2] = true;
+ kernel_reverse[3] = true;
+ kernel_reverse[4] = true;
+ } else {
+ kernel_reverse[0] = true;
+ kernel_reverse[1] = true;
+ kernel_reverse[2] = true;
+ kernel_reverse[3] = false;
+ kernel_reverse[4] = false;
+ }
+
+ DSizes<TensorIndex, 3> kernel_dims;
+ if (isColMajor) {
+ kernel_dims[0] = kernelFilters;
+ kernel_dims[1] = kernelChannels;
+ kernel_dims[2] = kernelRows * kernelCols * kernelPlanes;
+ } else {
+ kernel_dims[0] = kernelRows * kernelCols * kernelPlanes;
+ kernel_dims[1] = kernelChannels;
+ kernel_dims[2] = kernelFilters;
+ }
+
+ // The output_backward has dimensions out_depth X out_planes X out_rows X out_cols X OTHERS
+ // When we extract the image patches from output_backward, it will have dimensions:
+ // out_depth X (patch_planes * patch_rows * patch_cols) X (input_planes * input_rows * input_cols * OTHERS)
+ DSizes<TensorIndex, 3> pre_contract_dims;
+ if (isColMajor) {
+ pre_contract_dims[0] = kernelFilters;
+ pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes;
+ pre_contract_dims[2] = inputRows * inputCols * inputPlanes;
+ for (int i = 4; i < NumDims; ++i) {
+ pre_contract_dims[2] *= out.dimension(i);
+ }
+ } else {
+ pre_contract_dims[2] = kernelFilters;
+ pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes;
+ pre_contract_dims[0] = inputRows * inputCols * inputPlanes;
+ for (int i = 0; i < NumDims - 4; ++i) {
+ pre_contract_dims[0] *= out.dimension(i);
+ }
+ }
+
+ // We will contract along dimensions (0, 2) in kernel and (0, 1) in
+ // output_backward, if this is col-major, and
+ // dimensions (0, 2) in kernel and (1, 2) in output_backward, if this row-major.
+ array<IndexPair<TensorIndex>, 2> contract_dims;
+ if (isColMajor) {
+ // col-major: kernel.contract(output.patches)
+ contract_dims[0] = IndexPair<TensorIndex>(0, 0);
+ contract_dims[1] = IndexPair<TensorIndex>(2, 1);
+ } else {
+ // row-major: output.patches.contract(kernel)
+ contract_dims[0] = IndexPair<TensorIndex>(1, 0);
+ contract_dims[1] = IndexPair<TensorIndex>(2, 2);
+ }
+
+ // Post contraction, the dimensions of the input_backprop is
+ // channels X input_planes X input_rows X input_cols X OTHERS
+ DSizes<TensorIndex, NumDims> post_contract_dims;
+ if (isColMajor) {
+ post_contract_dims[0] = kernelChannels;
+ post_contract_dims[1] = inputPlanes;
+ post_contract_dims[2] = inputRows;
+ post_contract_dims[3] = inputCols;
+ for (int i = 4; i < NumDims; ++i) {
+ post_contract_dims[i] = out.dimension(i);
+ }
+ } else {
+ post_contract_dims[NumDims - 1] = kernelChannels;
+ post_contract_dims[NumDims - 2] = inputPlanes;
+ post_contract_dims[NumDims - 3] = inputRows;
+ post_contract_dims[NumDims - 4] = inputCols;
+ for (int i = 0; i < NumDims - 4; ++i) {
+ post_contract_dims[i] = out.dimension(i);
+ }
+ }
+
+ DSizes<TensorIndex, NumDims> strides;
+ for (int i = 0; i < NumDims; i++) {
+ strides[i] = 1;
+ }
+ if (isColMajor) {
+ strides[1] = stridePlanes;
+ strides[2] = strideRows;
+ strides[3] = strideCols;
+ } else {
+ strides[NumDims - 2] = stridePlanes;
+ strides[NumDims - 3] = strideRows;
+ strides[NumDims - 4] = strideCols;
+ }
+
+ return choose(
+ Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
+ kernel.reverse(kernel_reverse)
+ .reshape(kernel_dims)
+ .contract(
+ output_backward.extract_volume_patches(kernelPlanes, kernelRows, kernelCols,
+ 1, 1, 1, stridePlanes, strideRows, strideCols,
+ padding_ztop, padding_zbottom,
+ padding_top, padding_bottom,
+ padding_left, padding_right)
+ .reshape(pre_contract_dims),
+ contract_dims)
+ .reshape(post_contract_dims),
+ output_backward.extract_volume_patches(kernelPlanes, kernelRows, kernelCols,
+ 1, 1, 1, stridePlanes, strideRows, strideCols,
+ padding_ztop, padding_zbottom,
+ padding_top, padding_bottom,
+ padding_left, padding_right)
+ .reshape(pre_contract_dims)
+ .contract(kernel.reverse(kernel_reverse).reshape(kernel_dims),
+ contract_dims)
+ .reshape(post_contract_dims));
+}
+
+
+/** CuboidConvolutionBackwardKernel
+ * \ingroup CXX11_NeuralNetworks_Module
+ *
+ * \brief Computes the backprop for the filter of a 3D convolution.
+ *
+ * The output_backward parameter is expected to be a tensor with a rank of 4 or more (channels, depth, height, width, and optionally others)
+ * The kernel parameter is expected to be a 4D tensor (filters, channels, kernel_depth, kernel_height, kernel_width)
+ * output_backward and kernel have to be in the same layout.
+ *
+ * The dimensions of the result will be filters, depth, height, width (and others if applicable).
+ *
+ * It is possible to swap the order of the depth, width and height dimensions provided that the same order is used in the input, the kernel, and the output.
+ *
+ * All dimension orders above are given for col-major, and should be reversed for row-major.
+ */
+template <typename OutputBackward, typename Input>
+EIGEN_ALWAYS_INLINE static const typename internal::conditional<
+ internal::traits<OutputBackward>::Layout == ColMajor,
+ const TensorShufflingOp<
+ const array<typename internal::traits<OutputBackward>::Index, 5>,
+ const TensorReverseOp<
+ const array<bool, 5>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<OutputBackward>::Index, 5>,
+ const TensorContractionOp<
+ const array< IndexPair<typename internal::traits<Input>::Index>, 2>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 3>,
+ const Input>,
+ const TensorReshapingOp<
+ const DSizes< typename internal::traits<OutputBackward>::Index, 4>,
+ const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic, const OutputBackward>
+ >
+ >
+ >
+ >
+ >,
+ const TensorShufflingOp<
+ const array<typename internal::traits<OutputBackward>::Index, 5>,
+ const TensorReverseOp<
+ const array<bool, 5>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<OutputBackward>::Index, 5>,
+ const TensorContractionOp<
+ const array< IndexPair<typename internal::traits<Input>::Index>, 2>,
+ const TensorReshapingOp<
+ const DSizes< typename internal::traits<OutputBackward>::Index, 4>,
+ const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic, const OutputBackward>
+ >,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 3>,
+ const Input
+ >
+ >
+ >
+ >
+ >
+>::type
+CuboidConvolutionBackwardKernel(
+ const Input& input, const OutputBackward& output_backward,
+ typename internal::traits<Input>::Index kernelPlanes,
+ typename internal::traits<Input>::Index kernelRows,
+ typename internal::traits<Input>::Index kernelCols,
+ const DenseIndex stridePlanes = 1,
+ const DenseIndex strideRows = 1,
+ const DenseIndex strideCols = 1) {
+ typedef typename internal::traits<Input>::Index TensorIndex;
+ TensorRef<Tensor<typename internal::traits<Input>::Scalar, internal::traits<Input>::NumDimensions, internal::traits<Input>::Layout, TensorIndex> > in(input);
+ TensorRef<Tensor<typename internal::traits<OutputBackward>::Scalar, internal::traits<OutputBackward>::NumDimensions, internal::traits<OutputBackward>::Layout, TensorIndex> > out(output_backward);
+
+ EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == internal::traits<OutputBackward>::Layout, YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
+
+ static const int NumDims = internal::traits<Input>::NumDimensions;
+ EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == internal::traits<OutputBackward>::NumDimensions, YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ const TensorIndex inputPlanes = isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
+ const TensorIndex inputRows = isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
+ const TensorIndex inputCols = isColMajor ? in.dimension(3) : in.dimension(NumDims - 4);
+
+ const TensorIndex outputPlanes = isColMajor ? out.dimension(1) : out.dimension(NumDims - 2);
+ const TensorIndex outputRows = isColMajor ? out.dimension(2) : out.dimension(NumDims - 3);
+ const TensorIndex outputCols = isColMajor ? out.dimension(3) : out.dimension(NumDims - 4);
+
+ const TensorIndex kernelFilters = isColMajor ? out.dimension(0) : out.dimension(NumDims - 1);
+ const TensorIndex kernelChannels = isColMajor ? in.dimension(0) : in.dimension(NumDims - 1);
+
+ TensorIndex forward_pad_z, forward_pad_y, forward_pad_x;
+ const TensorIndex size_z = ceil(inputPlanes / static_cast<float>(stridePlanes));
+ const TensorIndex size_y = ceil(inputRows / static_cast<float>(strideRows));
+ const TensorIndex size_x = ceil(inputCols / static_cast<float>(strideCols));
+
+ // Infer padding type.
+ if (size_z == outputPlanes && size_y == outputRows && size_x == outputCols) {
+ // SAME padding.
+ const TensorIndex dz = size_z * stridePlanes + kernelPlanes - 1 - inputPlanes;
+ const TensorIndex dy = size_y * strideRows + kernelRows - 1 - inputRows;
+ const TensorIndex dx = size_x * strideCols + kernelCols - 1 - inputCols;
+
+ forward_pad_z = dz - dz / 2;
+ forward_pad_y = dy - dy / 2;
+ forward_pad_x = dx - dx / 2;
+ } else {
+ // VALID padding.
+ forward_pad_z = 0;
+ forward_pad_y = 0;
+ forward_pad_x = 0;
+ }
+
+ const TensorIndex padding_ztop = kernelPlanes - 1 - forward_pad_z;
+ const TensorIndex padding_top = kernelRows - 1 - forward_pad_y;
+ const TensorIndex padding_left = kernelCols - 1 - forward_pad_x;
+
+ const TensorIndex padding_zbottom = inputPlanes + kernelPlanes - 1 - (outputPlanes - 1) * stridePlanes - 1 - padding_ztop;
+ const TensorIndex padding_bottom = inputRows + kernelRows - 1 - (outputRows - 1) * strideRows - 1 - padding_top;
+ const TensorIndex padding_right = inputCols + kernelCols - 1 - (outputCols - 1) * strideCols - 1 - padding_left;
+
+ eigen_assert(padding_ztop >= 0);
+ eigen_assert(padding_zbottom >= 0);
+ eigen_assert(padding_top >= 0);
+ eigen_assert(padding_left >= 0);
+ eigen_assert(padding_bottom >= 0);
+ eigen_assert(padding_right >= 0);
+
+ // The output_backward has dimensions out_depth X out_plaens X out_rows X out_cols X OTHERS
+ // When we extract the image patches from output_backward (with input as the
+ // kernel), it will have dimensions
+ // (out_depth) X (input_planes * input_rows * input_cols) X (kernel_planes * kernel_rows * kernel_cols) X OTHERS
+ DSizes<TensorIndex, 4> pre_contract_dims;
+ if (isColMajor) {
+ pre_contract_dims[0] = kernelFilters;
+ pre_contract_dims[1] = inputRows * inputCols * inputPlanes;
+ pre_contract_dims[2] = kernelRows * kernelCols * kernelPlanes;
+ pre_contract_dims[3] = 1;
+ for (int i = 4; i < NumDims; ++i) {
+ pre_contract_dims[3] *= out.dimension(i);
+ }
+ } else {
+ pre_contract_dims[3] = kernelFilters;
+ pre_contract_dims[2] = inputRows * inputCols * inputPlanes;
+ pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes;
+ pre_contract_dims[0] = 1;
+ for (int i = 0; i < NumDims - 4; ++i) {
+ pre_contract_dims[0] *= out.dimension(i);
+ }
+ }
+
+ // The input has dimensions in_depth X (input_planes * input_rows * input_cols) X OTHERS
+ DSizes<TensorIndex, 3> input_dims;
+ if (isColMajor) {
+ input_dims[0] = kernelChannels;
+ input_dims[1] = inputRows * inputCols * inputPlanes;
+ input_dims[2] = 1;
+ for (int i = 4; i < NumDims; ++i) {
+ input_dims[2] *= in.dimension(i);
+ }
+ eigen_assert(input_dims[2] == pre_contract_dims[3]);
+ } else {
+ input_dims[2] = kernelChannels;
+ input_dims[1] = inputRows * inputCols * inputPlanes;
+ input_dims[0] = 1;
+ for (int i = 0; i < NumDims - 4; ++i) {
+ input_dims[0] *= in.dimension(i);
+ }
+ eigen_assert(input_dims[0] == pre_contract_dims[0]);
+ }
+
+ // We will contract along dimensions (1, 2) in in and (1, 3) in out, if
+ // this is col-major.
+ // For row-major, it's dimensions (0, 1) in in and (0, 2) in out.
+ array<IndexPair<TensorIndex>, 2> contract_dims;
+ if (isColMajor) {
+ // col-major: in.contract(output.patches)
+ contract_dims[0] = IndexPair<TensorIndex>(1, 1);
+ contract_dims[1] = IndexPair<TensorIndex>(2, 3);
+ } else {
+ // row-major: output.patches.contract(in)
+ contract_dims[0] = IndexPair<TensorIndex>(0, 0);
+ contract_dims[1] = IndexPair<TensorIndex>(2, 1);
+ }
+
+ // After the contraction, the kernel will have dimension
+ // in_depth X out_depth X kernel_patches X kernel_rows X kernel_cols
+ // We will need to shuffle the first two dimensions and reverse the spatial dimensions.
+ // The end shape is:
+ // out_depth X in_shape X kernel_planes X kernel_rows X kernel_cols
+
+ // This is the shape of the kernel *before* the shuffling.
+ DSizes<TensorIndex, 5> kernel_dims;
+ if (isColMajor) {
+ kernel_dims[0] = kernelChannels;
+ kernel_dims[1] = kernelFilters;
+ kernel_dims[2] = kernelPlanes;
+ kernel_dims[3] = kernelRows;
+ kernel_dims[4] = kernelCols;
+ } else {
+ kernel_dims[0] = kernelCols;
+ kernel_dims[1] = kernelRows;
+ kernel_dims[2] = kernelPlanes;
+ kernel_dims[3] = kernelFilters;
+ kernel_dims[4] = kernelChannels;
+ }
+
+ // Flip filters and channels.
+ array<TensorIndex, 5> kernel_shuffle;
+ if (isColMajor) {
+ kernel_shuffle[0] = 1;
+ kernel_shuffle[1] = 0;
+ kernel_shuffle[2] = 2;
+ kernel_shuffle[3] = 3;
+ kernel_shuffle[4] = 4;
+ } else {
+ kernel_shuffle[0] = 0;
+ kernel_shuffle[1] = 1;
+ kernel_shuffle[2] = 2;
+ kernel_shuffle[3] = 4;
+ kernel_shuffle[4] = 3;
+ }
+
+ // Reverse the spatial dimensions.
+ array<bool, 5> kernel_reverse;
+ if (isColMajor) {
+ kernel_reverse[0] = false;
+ kernel_reverse[1] = false;
+ kernel_reverse[2] = true;
+ kernel_reverse[3] = true;
+ kernel_reverse[4] = true;
+ } else {
+ kernel_reverse[0] = true;
+ kernel_reverse[1] = true;
+ kernel_reverse[2] = true;
+ kernel_reverse[3] = false;
+ kernel_reverse[4] = false;
+ }
+
+ DSizes<TensorIndex, NumDims> strides;
+ for (int i = 0; i < NumDims; i++) {
+ strides[i] = 1;
+ }
+ if (isColMajor) {
+ strides[1] = stridePlanes;
+ strides[2] = strideRows;
+ strides[3] = strideCols;
+ } else {
+ strides[NumDims - 2] = stridePlanes;
+ strides[NumDims - 3] = strideRows;
+ strides[NumDims - 4] = strideCols;
+ }
+ return choose(
+ Cond<internal::traits<Input>::Layout == ColMajor>(),
+ input.reshape(input_dims)
+ .contract(
+ output_backward.extract_volume_patches(
+ inputPlanes, inputRows, inputCols, 1,
+ 1, 1, stridePlanes, strideRows, strideCols,
+
+ padding_ztop, padding_zbottom, padding_top,
+ padding_bottom, padding_left, padding_right)
+ .reshape(pre_contract_dims),
+ contract_dims)
+ .reshape(kernel_dims)
+ .reverse(kernel_reverse)
+ .shuffle(kernel_shuffle),
+ output_backward.extract_volume_patches(
+ inputPlanes, inputRows, inputCols, 1, 1, 1,
+ stridePlanes, strideRows, strideCols, padding_ztop,
+ padding_zbottom, padding_top, padding_bottom,
+ padding_left, padding_right)
+ .reshape(pre_contract_dims)
+ .contract(input.reshape(input_dims), contract_dims)
+ .reshape(kernel_dims)
+ .reverse(kernel_reverse)
+ .shuffle(kernel_shuffle));
+}
+
+} // end namespace Eigen
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_CUBOID_CONVOLUTIONS_H_
diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
new file mode 100644
index 0000000000..7a5a94bb6f
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
@@ -0,0 +1,359 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace Eigen {
+
+/** SpatialConvolutionBackwardInput
+ * \ingroup CXX11_NeuralNetworks_Module
+ *
+ * \brief Computes the backprop for the input of a 2D convolution.
+ *
+ * The output_backward parameter is expected to be a tensor with a rank of 3 or more (channels, height, width, and optionally others)
+ * The kernel parameter is expected to be a 4D tensor (filters, channels, kernel_height, kernel_width)
+ * The output_backward and the kernel must both be in col-major layout. The result will also be in col-major layout.
+ *
+ * If in_stride > 1, then applies convolution with holes (aka atrous convolution), sampling every in_stride input pixels.
+ *
+ * The result can be assigned to a tensor of rank equal to the rank of the output_backward. The dimensions of the result will be filters, height, width (and others if applicable).
+ *
+ * It is possible to swap the order of the width and height dimensions provided that the same order is used in the input, the kernel, and the output.
+ *
+ */
+
+template <typename OutputBackward, typename Kernel>
+EIGEN_ALWAYS_INLINE
+static const typename internal::conditional<
+ internal::traits<OutputBackward>::Layout == ColMajor,
+ TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, internal::traits<OutputBackward>::NumDimensions>, const TensorContractionOp<const array<IndexPair<typename internal::traits<OutputBackward>::Index>, 2>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 3>, const TensorReverseOp<const array<bool, 4>, const Kernel> >, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 3>, const TensorImagePatchOp<Dynamic, Dynamic, const OutputBackward> > > >,
+ TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, internal::traits<OutputBackward>::NumDimensions>, const TensorContractionOp<const array<IndexPair<typename internal::traits<OutputBackward>::Index>, 2>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 3>, const TensorImagePatchOp<Dynamic, Dynamic, const OutputBackward> >, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 3>, const TensorReverseOp<const array<bool, 4>, const Kernel> > > > >::type
+SpatialConvolutionBackwardInput(const Kernel& kernel, const OutputBackward& output_backward, typename internal::traits<OutputBackward>::Index inputRows, typename internal::traits<OutputBackward>::Index inputCols, const DenseIndex stride = 1, const DenseIndex in_stride = 1) {
+
+ typedef typename internal::traits<OutputBackward>::Index TensorIndex;
+ TensorRef<Tensor<typename internal::traits<Kernel>::Scalar, internal::traits<Kernel>::NumDimensions, internal::traits<Kernel>::Layout, TensorIndex> > kern(kernel);
+ TensorRef<Tensor<typename internal::traits<OutputBackward>::Scalar, internal::traits<OutputBackward>::NumDimensions, internal::traits<OutputBackward>::Layout, TensorIndex> > out(output_backward);
+
+ EIGEN_STATIC_ASSERT(internal::traits<Kernel>::Layout == internal::traits<OutputBackward>::Layout, YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ static const bool isColMajor = (internal::traits<OutputBackward>::Layout == ColMajor);
+
+ static const int NumDims = internal::traits<OutputBackward>::NumDimensions;
+
+ // Number of filters to apply. This is the same as the output depth of the result
+ const TensorIndex kernelFilters = isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
+ // Number of channels. This is the same as the input depth.
+ const TensorIndex kernelChannels = isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
+ const TensorIndex kernelRows = isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
+ const TensorIndex kernelCols = isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];
+
+ // This is the effective kernel size, taking into account the (in_stride - 1) zero-values
+ // inserted between consecutive kernel elements in atrous convolution
+ const TensorIndex kernelRowsEff = kernelRows + (kernelRows - 1) * (in_stride - 1);
+ const TensorIndex kernelColsEff = kernelCols + (kernelCols - 1) * (in_stride - 1);
+
+ const TensorIndex outputRows = isColMajor ? output_backward.dimension(1) : output_backward.dimension(NumDims - 2);
+ const TensorIndex outputCols = isColMajor ? output_backward.dimension(2) : output_backward.dimension(NumDims - 3);
+
+ // Computing the forward padding
+ const TensorIndex forward_pad_top = ((outputRows - 1) * stride + kernelRowsEff - inputRows) / 2;
+ const TensorIndex forward_pad_left = ((outputCols - 1) * stride + kernelColsEff - inputCols) / 2;
+
+ const TensorIndex padding_top = kernelRowsEff - 1 - forward_pad_top;
+ const TensorIndex padding_left = kernelColsEff - 1 - forward_pad_left;
+ const TensorIndex padding_bottom = inputRows + kernelRowsEff - 1 - (outputRows - 1) * stride - 1 - padding_top;
+ const TensorIndex padding_right = inputCols + kernelColsEff - 1 - (outputCols - 1) * stride - 1 - padding_left;
+
+ eigen_assert(padding_top >= 0);
+ eigen_assert(padding_left >= 0);
+ eigen_assert(padding_bottom >= 0);
+ eigen_assert(padding_right >= 0);
+
+ // The kernel has dimensions filters X channels X patch_rows X patch_cols
+ // We need to reverse the kernel along dimensions corresponding to rows and
+ // cols.
+ // TODO(yangke): we can make things slightly faster by collapsing the dimensions
+ // where we don't reverse. Try that once we have a faster compiler.
+ array<bool, 4> kernel_reverse;
+ if (isColMajor) {
+ kernel_reverse[0] = false;
+ kernel_reverse[1] = false;
+ kernel_reverse[2] = true;
+ kernel_reverse[3] = true;
+ } else {
+ kernel_reverse[0] = true;
+ kernel_reverse[1] = true;
+ kernel_reverse[2] = false;
+ kernel_reverse[3] = false;
+ }
+
+ DSizes<TensorIndex, 3> kernel_dims;
+ if (isColMajor) {
+ kernel_dims[0] = kernelFilters;
+ kernel_dims[1] = kernelChannels;
+ kernel_dims[2] = kernelRows * kernelCols;
+ } else {
+ kernel_dims[0] = kernelRows * kernelCols;
+ kernel_dims[1] = kernelChannels;
+ kernel_dims[2] = kernelFilters;
+ }
+
+ // The output_backward has dimensions out_depth X out_rows X out_cols X OTHERS
+ // When we extract the image patches from output_backward, it will have dimensions
+ // out_depth X (patch_rows * patch_cols) X (input_rows * input_cols * OTHERS)
+ DSizes<TensorIndex, 3> pre_contract_dims;
+ if (isColMajor) {
+ pre_contract_dims[0] = kernelFilters;
+ pre_contract_dims[1] = kernelRows * kernelCols;
+ pre_contract_dims[2] = inputRows * inputCols;
+ for (int i = 3; i < NumDims; ++i) {
+ pre_contract_dims[2] *= out.dimension(i);
+ }
+ } else {
+ pre_contract_dims[2] = kernelFilters;
+ pre_contract_dims[1] = kernelRows * kernelCols;
+ pre_contract_dims[0] = inputRows * inputCols;
+ for (int i = 0; i < NumDims - 3; ++i) {
+ pre_contract_dims[0] *= out.dimension(i);
+ }
+ }
+
+ // We will contract along dimensions (0, 2) in kernel and (0, 1) in
+ // output_backward, if this is col-major, and
+ // dimensions (0, 2) in kernel and (1, 2) in output_backward, if this row-major.
+ array<IndexPair<TensorIndex>, 2> contract_dims;
+ if (isColMajor) {
+ // col-major: kernel.contract(output.patches)
+ contract_dims[0] = IndexPair<TensorIndex>(0, 0);
+ contract_dims[1] = IndexPair<TensorIndex>(2, 1);
+ } else {
+ // row-major: output.patches.contract(kernel)
+ contract_dims[0] = IndexPair<TensorIndex>(1, 0);
+ contract_dims[1] = IndexPair<TensorIndex>(2, 2);
+ }
+
+ // Post contraction, the dimensions of the input_backprop is
+ // channels X input_rows X input_cols X OTHERS
+ DSizes<TensorIndex, NumDims> post_contract_dims;
+ if (isColMajor) {
+ post_contract_dims[0] = kernelChannels;
+ post_contract_dims[1] = inputRows;
+ post_contract_dims[2] = inputCols;
+ for (int i = 3; i < NumDims; ++i) {
+ post_contract_dims[i] = out.dimension(i);
+ }
+ } else {
+ post_contract_dims[NumDims - 1] = kernelChannels;
+ post_contract_dims[NumDims - 2] = inputRows;
+ post_contract_dims[NumDims - 3] = inputCols;
+ for (int i = 0; i < NumDims - 3; ++i) {
+ post_contract_dims[i] = out.dimension(i);
+ }
+ }
+
+ return choose(Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
+ kernel.reverse(kernel_reverse).reshape(kernel_dims).contract(output_backward.extract_image_patches(kernelRows, kernelCols, 1, 1, in_stride, in_stride, stride, stride, padding_top, padding_bottom, padding_left, padding_right, 0).reshape(pre_contract_dims), contract_dims).reshape(post_contract_dims),
+ output_backward.extract_image_patches(kernelRows, kernelCols, 1, 1, in_stride, in_stride, stride, stride, padding_top, padding_bottom, padding_left, padding_right, 0).reshape(pre_contract_dims).contract(kernel.reverse(kernel_reverse).reshape(kernel_dims), contract_dims).reshape(post_contract_dims));
+}
+
+
+/** SpatialConvolutionBackwardKernel
+ * \ingroup CXX11_NeuralNetworks_Module
+ *
+ * \brief Computes the backprop for the filter of a 2D convolution.
+ *
+ * The output_backward parameter is expected to be a tensor with a rank of 3 or more (channels, height, width, and optionally others)
+ * The kernel parameter is expected to be a 4D tensor (filters, channels, kernel_height, kernel_width)
+ * The output_backward and the kernel must both be in col-major layout. The result will also be in col-major layout.
+ *
+ * If in_stride > 1, then applies convolution with holes (aka atrous convolution), sampling every in_stride input pixels.
+ *
+ * The result can be assigned to a tensor of rank equal to the rank of the output_backward. The dimensions of the result will be filters, height, width (and others if applicable).
+ *
+ * It is possible to swap the order of the width and height dimensions provided that the same order is used in the input, the kernel, and the output.
+ *
+ */
+// TODO(gpapan): Resolve a bug in TensorContractionInputMapper at SpatialConvolutions.h that yangke circumvented by using .reshape().reshape().
+// This can significantly accelerate SpatialConvolutionBackwardKernel.
+
+template <typename OutputBackward, typename Input>
+EIGEN_ALWAYS_INLINE
+static const typename internal::conditional<
+ internal::traits<OutputBackward>::Layout == ColMajor,
+ const TensorShufflingOp<const array<typename internal::traits<OutputBackward>::Index, 4>, const TensorReverseOp<const array<bool, 4>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorContractionOp<const array<IndexPair<typename internal::traits<Input>::Index>, 2>, const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 3>, const Input>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorImagePatchOp<Dynamic, Dynamic, const OutputBackward> > > > > > >,
+ const TensorShufflingOp<const array<typename internal::traits<OutputBackward>::Index, 4>, const TensorReverseOp<const array<bool, 4>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorContractionOp<const array<IndexPair<typename internal::traits<Input>::Index>, 2>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorImagePatchOp<Dynamic, Dynamic, const OutputBackward> > >, const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 3>, const Input> > > > > >::type
+SpatialConvolutionBackwardKernel(const Input& input, const OutputBackward& output_backward, typename internal::traits<Input>::Index kernelRows, typename internal::traits<Input>::Index kernelCols, const DenseIndex stride = 1, const DenseIndex in_stride = 1) {
+
+ typedef typename internal::traits<Input>::Index TensorIndex;
+ TensorRef<Tensor<typename internal::traits<Input>::Scalar, internal::traits<Input>::NumDimensions, internal::traits<Input>::Layout, TensorIndex> > in(input);
+ TensorRef<Tensor<typename internal::traits<OutputBackward>::Scalar, internal::traits<OutputBackward>::NumDimensions, internal::traits<OutputBackward>::Layout, TensorIndex> > out(output_backward);
+
+ EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == internal::traits<OutputBackward>::Layout, YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ // stride and in_stride cannot both be larger than 1
+ eigen_assert(!(stride > 1 && in_stride > 1));
+
+ static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
+
+ static const int NumDims = internal::traits<Input>::NumDimensions;
+ EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == internal::traits<OutputBackward>::NumDimensions, YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ const TensorIndex inputRows = isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
+ const TensorIndex inputCols = isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
+
+ const TensorIndex outputRows = isColMajor ? output_backward.dimension(1) : output_backward.dimension(NumDims - 2);
+ const TensorIndex outputCols = isColMajor ? output_backward.dimension(2) : output_backward.dimension(NumDims - 3);
+
+ // Number of filters to apply. This is the same as the output depth of the result
+ const TensorIndex kernelFilters = isColMajor ? out.dimensions()[0] : out.dimensions()[NumDims - 1];
+
+ // Number of channels. This is the same as the input depth.
+ const TensorIndex kernelChannels = isColMajor ? in.dimensions()[0] : in.dimensions()[NumDims - 1];
+
+ // This is the effective kernel size, taking into account the (in_stride - 1) zero-values
+ // inserted between consecutive kernel elements in atrous convolution
+ const TensorIndex kernelRowsEff = kernelRows + (kernelRows - 1) * (in_stride - 1);
+ const TensorIndex kernelColsEff = kernelCols + (kernelCols - 1) * (in_stride - 1);
+
+ // Computing the forward padding
+ const TensorIndex forward_pad_top = ((outputRows - 1) * stride + kernelRowsEff - inputRows) / 2;
+ const TensorIndex forward_pad_left = ((outputCols - 1) * stride + kernelColsEff - inputCols) / 2;
+
+ // TODO: factor out the padding computation.
+ const TensorIndex padding_top = kernelRowsEff - 1 - forward_pad_top;
+ const TensorIndex padding_left = kernelColsEff - 1 - forward_pad_left;
+ const TensorIndex padding_bottom = inputRows + kernelRowsEff - 1 - (outputRows - 1) * stride - 1 - padding_top;
+ const TensorIndex padding_right = inputCols + kernelColsEff - 1 - (outputCols - 1) * stride - 1 - padding_left;
+
+ eigen_assert(padding_top >= 0);
+ eigen_assert(padding_left >= 0);
+ eigen_assert(padding_bottom >= 0);
+ eigen_assert(padding_right >= 0);
+
+ // The output_backward has dimensions out_depth X out_rows X out_cols X OTHERS
+ // When we extract the image patches from output_backward (with input as the
+ // kernel), it will have dimensions
+ // (out_depth) X (input_rows * input_cols) X (kernel_rows * kernel_cols) X OTHERS
+ DSizes<TensorIndex, 4> pre_contract_dims;
+ if (isColMajor) {
+ pre_contract_dims[0] = kernelFilters;
+ pre_contract_dims[1] = inputRows * inputCols;
+ pre_contract_dims[2] = kernelRows * kernelCols;
+ pre_contract_dims[3] = 1;
+ for (int i = 3; i < NumDims; ++i) {
+ pre_contract_dims[3] *= out.dimension(i);
+ }
+ } else {
+ pre_contract_dims[3] = kernelFilters;
+ pre_contract_dims[2] = inputRows * inputCols;
+ pre_contract_dims[1] = kernelRows * kernelCols;
+ pre_contract_dims[0] = 1;
+ for (int i = 0; i < NumDims - 3; ++i) {
+ pre_contract_dims[0] *= out.dimension(i);
+ }
+ }
+
+ // The input has dimensions in_depth X (input_rows * input_cols) X OTHERS
+ DSizes<TensorIndex, 3> input_dims;
+ if (isColMajor) {
+ input_dims[0] = kernelChannels;
+ input_dims[1] = inputRows * inputCols;
+ input_dims[2] = 1;
+ for (int i = 3; i < NumDims; ++i) {
+ input_dims[2] *= in.dimension(i);
+ }
+ eigen_assert(input_dims[2] == pre_contract_dims[3]);
+ } else {
+ input_dims[2] = kernelChannels;
+ input_dims[1] = inputRows * inputCols;
+ input_dims[0] = 1;
+ for (int i = 0; i < NumDims - 3; ++i) {
+ input_dims[0] *= in.dimension(i);
+ }
+ eigen_assert(input_dims[0] == pre_contract_dims[0]);
+ }
+
+ // We will contract along dimensions (1, 2) in in and (1, 3) in out, if
+ // this is col-major.
+ // For row-major, it's dimensions (0, 1) in in and (0, 2) in out.
+ array<IndexPair<TensorIndex>, 2> contract_dims;
+ if (isColMajor) {
+ // col-major: in.contract(output.patches)
+ contract_dims[0] = IndexPair<TensorIndex>(1, 1);
+ contract_dims[1] = IndexPair<TensorIndex>(2, 3);
+ } else {
+ // row-major: output.patches.contract(in)
+ contract_dims[0] = IndexPair<TensorIndex>(0, 0);
+ contract_dims[1] = IndexPair<TensorIndex>(2, 1);
+ }
+
+ // After the contraction, the kernel will have dimension
+ // in_depth X out_depth X kernel_rows X kernel_cols
+ // We will need to shuffle the first two dimensions and reverse the latter
+ // two dimensions.
+ // The end shape is
+ // out_depth X in_shape X kernel_rows X kernel_cols
+
+ // This is the shape of the kernel *before* the shuffling.
+ DSizes<TensorIndex, 4> kernel_dims;
+ if (isColMajor) {
+ kernel_dims[0] = kernelChannels;
+ kernel_dims[1] = kernelFilters;
+ kernel_dims[2] = kernelRows;
+ kernel_dims[3] = kernelCols;
+ } else {
+ kernel_dims[0] = kernelCols;
+ kernel_dims[1] = kernelRows;
+ kernel_dims[2] = kernelFilters;
+ kernel_dims[3] = kernelChannels;
+ }
+
+ array<TensorIndex, 4> kernel_shuffle;
+ if (isColMajor) {
+ kernel_shuffle[0] = 1;
+ kernel_shuffle[1] = 0;
+ kernel_shuffle[2] = 2;
+ kernel_shuffle[3] = 3;
+ } else {
+ kernel_shuffle[0] = 0;
+ kernel_shuffle[1] = 1;
+ kernel_shuffle[2] = 3;
+ kernel_shuffle[3] = 2;
+ }
+
+ array<bool, 4> kernel_reverse;
+ if (isColMajor) {
+ kernel_reverse[0] = false;
+ kernel_reverse[1] = false;
+ kernel_reverse[2] = true;
+ kernel_reverse[3] = true;
+ } else {
+ kernel_reverse[0] = true;
+ kernel_reverse[1] = true;
+ kernel_reverse[2] = false;
+ kernel_reverse[3] = false;
+ }
+
+ return choose(Cond<internal::traits<Input>::Layout == ColMajor>(),
+ input.reshape(input_dims).contract(output_backward.extract_image_patches(inputRows, inputCols, in_stride, in_stride, 1, 1, stride, stride, padding_top, padding_bottom, padding_left, padding_right, 0).reshape(pre_contract_dims).reshape(pre_contract_dims), contract_dims).reshape(kernel_dims).reverse(kernel_reverse).shuffle(kernel_shuffle),
+ output_backward.extract_image_patches(inputRows, inputCols, in_stride, in_stride, 1, 1, stride, stride, padding_top, padding_bottom, padding_left, padding_right, 0).reshape(pre_contract_dims).reshape(pre_contract_dims).contract(input.reshape(input_dims), contract_dims).reshape(kernel_dims).reverse(kernel_reverse).shuffle(kernel_shuffle));
+}
+
+} // end namespace Eigen
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_
diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
new file mode 100644
index 0000000000..9e77a71cb5
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
@@ -0,0 +1,1959 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/eigen_backward_spatial_convolutions.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace Eigen {
+
+namespace {
+void EigenApprox(float a, float b) {
+ ASSERT_TRUE(std::abs(a - b) <= std::min(std::abs(a), std::abs(b)) * 1e-3);
+}
+static int ceil_div(int a, int b) { return (a + b - 1) / b; }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_simple_spatial_convolution_backward_input_valid) {
+ const int input_depth = 2;
+ const int input_rows = 3;
+ const int input_cols = 4;
+ const int output_depth = 5;
+ const int patch_rows = 2;
+ const int patch_cols = 2;
+ const int output_rows = input_rows - patch_rows + 1;
+ const int output_cols = input_cols - patch_cols + 1;
+
+ Tensor<float, 3> input_backward(input_depth, input_rows, input_cols);
+ Tensor<float, 4> kernel(output_depth, input_depth, patch_rows, patch_cols);
+ Tensor<float, 3> output_backward(output_depth, output_rows, output_cols);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ input_backward.setRandom();
+
+ input_backward = SpatialConvolutionBackwardInput(kernel, output_backward,
+ input_rows, input_cols, 1);
+
+ EXPECT_EQ(input_backward.dimension(0), input_depth);
+ EXPECT_EQ(input_backward.dimension(1), input_rows);
+ EXPECT_EQ(input_backward.dimension(2), input_cols);
+
+ for (int id = 0; id < input_depth; ++id) {
+ for (int i = 0; i < input_rows; ++i) {
+ for (int j = 0; j < input_cols; ++j) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int od = 0; od < output_depth; ++od) {
+ int output_i = i - r;
+ int output_j = j - c;
+ if (output_i >= 0 && output_i < output_rows && output_j >= 0 &&
+ output_j < output_cols) {
+ expected += output_backward(od, output_i, output_j) *
+ kernel(od, id, r, c);
+ }
+ }
+ }
+ }
+ EigenApprox(input_backward(id, i, j), expected);
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_simple_spatial_convolution_backward_input_valid_row_major) {
+ const int input_depth = 2;
+ const int input_rows = 3;
+ const int input_cols = 4;
+ const int output_depth = 5;
+ const int patch_rows = 2;
+ const int patch_cols = 2;
+ const int output_rows = input_rows - patch_rows + 1;
+ const int output_cols = input_cols - patch_cols + 1;
+
+ Tensor<float, 3, RowMajor> input_backward(input_cols, input_rows,
+ input_depth);
+ Tensor<float, 4, RowMajor> kernel(patch_cols, patch_rows, input_depth,
+ output_depth);
+ Tensor<float, 3, RowMajor> output_backward(output_cols, output_rows,
+ output_depth);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ input_backward.setRandom();
+
+ input_backward = SpatialConvolutionBackwardInput(kernel, output_backward,
+ input_rows, input_cols, 1);
+
+ EXPECT_EQ(input_backward.dimension(0), input_cols);
+ EXPECT_EQ(input_backward.dimension(1), input_rows);
+ EXPECT_EQ(input_backward.dimension(2), input_depth);
+
+ for (int id = 0; id < input_depth; ++id) {
+ for (int i = 0; i < input_rows; ++i) {
+ for (int j = 0; j < input_cols; ++j) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int od = 0; od < output_depth; ++od) {
+ int output_i = i - r;
+ int output_j = j - c;
+ if (output_i >= 0 && output_i < output_rows && output_j >= 0 &&
+ output_j < output_cols) {
+ expected += output_backward(output_j, output_i, od) *
+ kernel(c, r, id, od);
+ }
+ }
+ }
+ }
+ EigenApprox(input_backward(j, i, id), expected);
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_simple_cuboid_convolution_backward_input_valid) {
+ const int input_depth = 2;
+ const int input_planes = 5;
+ const int input_rows = 3;
+ const int input_cols = 4;
+ const int patch_rows = 2;
+ const int patch_cols = 2;
+ const int patch_planes = 2;
+ const int output_rows = input_rows - patch_rows + 1;
+ const int output_cols = input_cols - patch_cols + 1;
+ const int output_planes = input_planes - patch_planes + 1;
+ const int output_depth = 5;
+
+ Tensor<float, 4> input_backward(input_depth, input_planes, input_rows,
+ input_cols);
+ Tensor<float, 5> kernel(output_depth, input_depth, patch_planes, patch_rows,
+ patch_cols);
+ Tensor<float, 4> output_backward(output_depth, output_planes, output_rows,
+ output_cols);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ input_backward.setRandom();
+
+ input_backward = CuboidConvolutionBackwardInput(
+ kernel, output_backward, input_planes, input_rows, input_cols);
+
+ EXPECT_EQ(input_backward.dimension(3), input_cols);
+ EXPECT_EQ(input_backward.dimension(2), input_rows);
+ EXPECT_EQ(input_backward.dimension(1), input_planes);
+ EXPECT_EQ(input_backward.dimension(0), input_depth);
+
+ for (int id = 0; id < input_depth; ++id) {
+ for (int i = 0; i < input_planes; ++i) {
+ for (int j = 0; j < input_rows; ++j) {
+ for (int k = 0; k < input_cols; ++k) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int p = 0; p < patch_planes; ++p) {
+ for (int od = 0; od < output_depth; ++od) {
+ int output_j = j - r;
+ int output_k = k - c;
+ int output_i = i - p;
+ if (output_i >= 0 && output_i < output_planes &&
+ output_j >= 0 && output_j < output_rows &&
+ output_k >= 0 && output_k < output_cols) {
+ expected +=
+ output_backward(od, output_i, output_j, output_k) *
+ kernel(od, id, p, r, c);
+ }
+ }
+ }
+ }
+ }
+ EigenApprox(input_backward(id, i, j, k), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_simple_cuboid_convolution_backward_input_valid_row_major) {
+ const int input_depth = 2;
+ const int input_planes = 5;
+ const int input_rows = 3;
+ const int input_cols = 4;
+ const int patch_rows = 2;
+ const int patch_cols = 2;
+ const int patch_planes = 2;
+ const int output_rows = input_rows - patch_rows + 1;
+ const int output_cols = input_cols - patch_cols + 1;
+ const int output_planes = input_planes - patch_planes + 1;
+ const int output_depth = 5;
+
+ Tensor<float, 4, RowMajor> input_backward(input_cols, input_rows,
+ input_planes, input_depth);
+ Tensor<float, 5, RowMajor> kernel(patch_cols, patch_rows, patch_planes,
+ input_depth, output_depth);
+ Tensor<float, 4, RowMajor> output_backward(output_cols, output_rows,
+ output_planes, output_depth);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ input_backward.setRandom();
+
+ input_backward = CuboidConvolutionBackwardInput(
+ kernel, output_backward, input_planes, input_rows, input_cols);
+
+ EXPECT_EQ(input_backward.dimension(0), input_cols);
+ EXPECT_EQ(input_backward.dimension(1), input_rows);
+ EXPECT_EQ(input_backward.dimension(2), input_planes);
+ EXPECT_EQ(input_backward.dimension(3), input_depth);
+
+ for (int id = 0; id < input_depth; ++id) {
+ for (int i = 0; i < input_planes; ++i) {
+ for (int j = 0; j < input_rows; ++j) {
+ for (int k = 0; k < input_cols; ++k) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int p = 0; p < patch_planes; ++p) {
+ for (int od = 0; od < output_depth; ++od) {
+ int output_j = j - r;
+ int output_k = k - c;
+ int output_i = i - p;
+ if (output_i >= 0 && output_i < output_planes &&
+ output_j >= 0 && output_j < output_rows &&
+ output_k >= 0 && output_k < output_cols) {
+ expected +=
+ output_backward(output_k, output_j, output_i, od) *
+ kernel(c, r, p, id, od);
+ }
+ }
+ }
+ }
+ }
+ EigenApprox(input_backward(k, j, i, id), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_simple_spatial_convolution_backward_input_same) {
+ const int input_depth = 2;
+ const int input_rows = 7;
+ const int input_cols = 9;
+ const int output_depth = 3;
+ const int patch_rows = 4;
+ const int patch_cols = 4;
+ const int output_rows = input_rows;
+ const int output_cols = input_cols;
+
+ Tensor<float, 3> input_backward(input_depth, input_rows, input_cols);
+ Tensor<float, 4> kernel(output_depth, input_depth, patch_rows, patch_cols);
+ Tensor<float, 3> output_backward(output_depth, output_rows, output_cols);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+
+ input_backward = SpatialConvolutionBackwardInput(kernel, output_backward,
+ input_rows, input_cols, 1);
+
+ EXPECT_EQ(input_backward.dimension(0), input_depth);
+ EXPECT_EQ(input_backward.dimension(1), input_rows);
+ EXPECT_EQ(input_backward.dimension(2), input_cols);
+
+ for (int id = 0; id < input_depth; ++id) {
+ for (int i = 0; i < input_rows; ++i) {
+ for (int j = 0; j < input_cols; ++j) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int od = 0; od < output_depth; ++od) {
+ int output_i = i - r + (patch_rows - 1) / 2;
+ int output_j = j - c + (patch_cols - 1) / 2;
+ if (output_i >= 0 && output_i < output_rows && output_j >= 0 &&
+ output_j < output_cols) {
+ expected += output_backward(od, output_i, output_j) *
+ kernel(od, id, r, c);
+ }
+ }
+ }
+ }
+ EigenApprox(input_backward(id, i, j), expected);
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_simple_spatial_convolution_backward_input_same_row_major) {
+ const int input_depth = 2;
+ const int input_rows = 7;
+ const int input_cols = 9;
+ const int output_depth = 3;
+ const int patch_rows = 4;
+ const int patch_cols = 4;
+ const int output_rows = input_rows;
+ const int output_cols = input_cols;
+
+ Tensor<float, 3, RowMajor> input_backward(input_cols, input_rows,
+ input_depth);
+ Tensor<float, 4, RowMajor> kernel(patch_cols, patch_rows, input_depth,
+ output_depth);
+ Tensor<float, 3, RowMajor> output_backward(output_cols, output_rows,
+ output_depth);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+
+ input_backward = SpatialConvolutionBackwardInput(kernel, output_backward,
+ input_rows, input_cols, 1);
+
+ EXPECT_EQ(input_backward.dimension(0), input_cols);
+ EXPECT_EQ(input_backward.dimension(1), input_rows);
+ EXPECT_EQ(input_backward.dimension(2), input_depth);
+
+ for (int id = 0; id < input_depth; ++id) {
+ for (int i = 0; i < input_rows; ++i) {
+ for (int j = 0; j < input_cols; ++j) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int od = 0; od < output_depth; ++od) {
+ int output_i = i - r + (patch_rows - 1) / 2;
+ int output_j = j - c + (patch_cols - 1) / 2;
+ if (output_i >= 0 && output_i < output_rows && output_j >= 0 &&
+ output_j < output_cols) {
+ expected += output_backward(output_j, output_i, od) *
+ kernel(c, r, id, od);
+ }
+ }
+ }
+ }
+ EigenApprox(input_backward(j, i, id), expected);
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_simple_cuboid_convolution_backward_input_same) {
+ const int input_depth = 2;
+ const int input_planes = 5;
+ const int input_rows = 3;
+ const int input_cols = 4;
+ const int patch_rows = 3;
+ const int patch_cols = 2;
+ const int patch_planes = 4;
+ const int output_rows = input_rows;
+ const int output_cols = input_cols;
+ const int output_planes = input_planes;
+ const int output_depth = 5;
+
+ Tensor<float, 4> input_backward(input_depth, input_planes, input_rows,
+ input_cols);
+ Tensor<float, 5> kernel(output_depth, input_depth, patch_planes, patch_rows,
+ patch_cols);
+ Tensor<float, 4> output_backward(output_depth, output_planes, output_rows,
+ output_cols);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ input_backward.setRandom();
+
+ input_backward = CuboidConvolutionBackwardInput(
+ kernel, output_backward, input_planes, input_rows, input_cols);
+
+ EXPECT_EQ(input_backward.dimension(3), input_cols);
+ EXPECT_EQ(input_backward.dimension(2), input_rows);
+ EXPECT_EQ(input_backward.dimension(1), input_planes);
+ EXPECT_EQ(input_backward.dimension(0), input_depth);
+
+ const int dz = patch_planes - 1;
+ const int dy = patch_rows - 1;
+ const int dx = patch_cols - 1;
+
+ const int forward_pad_x = dx - dx / 2;
+ const int forward_pad_y = dy - dy / 2;
+ const int forward_pad_z = dz - dz / 2;
+
+ for (int id = 0; id < input_depth; ++id) {
+ for (int i = 0; i < input_planes; ++i) {
+ for (int j = 0; j < input_rows; ++j) {
+ for (int k = 0; k < input_cols; ++k) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int p = 0; p < patch_planes; ++p) {
+ for (int od = 0; od < output_depth; ++od) {
+ int output_i = i - p + forward_pad_z;
+ int output_j = j - r + forward_pad_y;
+ int output_k = k - c + forward_pad_x;
+ if (output_i >= 0 && output_i < output_planes &&
+ output_j >= 0 && output_j < output_rows &&
+ output_k >= 0 && output_k < output_cols) {
+ expected +=
+ output_backward(od, output_i, output_j, output_k) *
+ kernel(od, id, p, r, c);
+ }
+ }
+ }
+ }
+ }
+ EigenApprox(input_backward(id, i, j, k), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_simple_cuboid_convolution_backward_input_same_row_major) {
+ const int input_depth = 2;
+ const int input_planes = 5;
+ const int input_rows = 3;
+ const int input_cols = 4;
+ const int patch_rows = 2;
+ const int patch_cols = 3;
+ const int patch_planes = 4;
+ const int output_rows = input_rows;
+ const int output_cols = input_cols;
+ const int output_planes = input_planes;
+ const int output_depth = 5;
+
+ Tensor<float, 4, RowMajor> input_backward(input_cols, input_rows,
+ input_planes, input_depth);
+ Tensor<float, 5, RowMajor> kernel(patch_cols, patch_rows, patch_planes,
+ input_depth, output_depth);
+ Tensor<float, 4, RowMajor> output_backward(output_cols, output_rows,
+ output_planes, output_depth);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ input_backward.setRandom();
+
+ input_backward = CuboidConvolutionBackwardInput(
+ kernel, output_backward, input_planes, input_rows, input_cols);
+
+ EXPECT_EQ(input_backward.dimension(0), input_cols);
+ EXPECT_EQ(input_backward.dimension(1), input_rows);
+ EXPECT_EQ(input_backward.dimension(2), input_planes);
+ EXPECT_EQ(input_backward.dimension(3), input_depth);
+
+ const int dz = patch_planes - 1;
+ const int dy = patch_rows - 1;
+ const int dx = patch_cols - 1;
+
+ const int forward_pad_x = dx - dx / 2;
+ const int forward_pad_y = dy - dy / 2;
+ const int forward_pad_z = dz - dz / 2;
+
+ for (int id = 0; id < input_depth; ++id) {
+ for (int i = 0; i < input_planes; ++i) {
+ for (int j = 0; j < input_rows; ++j) {
+ for (int k = 0; k < input_cols; ++k) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int p = 0; p < patch_planes; ++p) {
+ for (int od = 0; od < output_depth; ++od) {
+ int output_i = i - p + forward_pad_z;
+ int output_j = j - r + forward_pad_y;
+ int output_k = k - c + forward_pad_x;
+ if (output_i >= 0 && output_i < output_planes &&
+ output_j >= 0 && output_j < output_rows &&
+ output_k >= 0 && output_k < output_cols) {
+ expected +=
+ output_backward(output_k, output_j, output_i, od) *
+ kernel(c, r, p, id, od);
+ }
+ }
+ }
+ }
+ }
+ EigenApprox(input_backward(k, j, i, id), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_batched_spatial_convolution_backward_input_valid) {
+ const int num_batches = 13;
+ const int input_depth = 2;
+ const int input_rows = 7;
+ const int input_cols = 9;
+ const int output_depth = 3;
+ const int patch_rows = 5;
+ const int patch_cols = 5;
+ const int output_rows = input_rows - patch_rows + 1;
+ const int output_cols = input_cols - patch_cols + 1;
+
+ Tensor<float, 4> input_backward(input_depth, input_rows, input_cols,
+ num_batches);
+ Tensor<float, 4> kernel(output_depth, input_depth, patch_rows, patch_cols);
+ Tensor<float, 4> output_backward(output_depth, output_rows, output_cols,
+ num_batches);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ input_backward.setRandom();
+
+ input_backward = SpatialConvolutionBackwardInput(kernel, output_backward,
+ input_rows, input_cols, 1);
+
+ EXPECT_EQ(input_backward.dimension(0), input_depth);
+ EXPECT_EQ(input_backward.dimension(1), input_rows);
+ EXPECT_EQ(input_backward.dimension(2), input_cols);
+ EXPECT_EQ(input_backward.dimension(3), num_batches);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int id = 0; id < input_depth; ++id) {
+ for (int i = 0; i < input_rows; ++i) {
+ for (int j = 0; j < input_cols; ++j) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int od = 0; od < output_depth; ++od) {
+ int output_i = i - r;
+ int output_j = j - c;
+ if (output_i >= 0 && output_i < output_rows && output_j >= 0 &&
+ output_j < output_cols) {
+ expected += output_backward(od, output_i, output_j, b) *
+ kernel(od, id, r, c);
+ }
+ }
+ }
+ }
+ EigenApprox(input_backward(id, i, j, b), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_batched_spatial_convolution_backward_input_valid_row_major) {
+ const int num_batches = 13;
+ const int input_depth = 2;
+ const int input_rows = 7;
+ const int input_cols = 9;
+ const int output_depth = 3;
+ const int patch_rows = 5;
+ const int patch_cols = 5;
+ const int output_rows = input_rows - patch_rows + 1;
+ const int output_cols = input_cols - patch_cols + 1;
+
+ Tensor<float, 4, RowMajor> input_backward(num_batches, input_cols, input_rows,
+ input_depth);
+ Tensor<float, 4, RowMajor> kernel(patch_cols, patch_rows, input_depth,
+ output_depth);
+ Tensor<float, 4, RowMajor> output_backward(num_batches, output_cols,
+ output_rows, output_depth);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ input_backward.setRandom();
+
+ input_backward = SpatialConvolutionBackwardInput(kernel, output_backward,
+ input_rows, input_cols, 1);
+
+ EXPECT_EQ(input_backward.dimension(0), num_batches);
+ EXPECT_EQ(input_backward.dimension(1), input_cols);
+ EXPECT_EQ(input_backward.dimension(2), input_rows);
+ EXPECT_EQ(input_backward.dimension(3), input_depth);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int id = 0; id < input_depth; ++id) {
+ for (int i = 0; i < input_rows; ++i) {
+ for (int j = 0; j < input_cols; ++j) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int od = 0; od < output_depth; ++od) {
+ int output_i = i - r;
+ int output_j = j - c;
+ if (output_i >= 0 && output_i < output_rows && output_j >= 0 &&
+ output_j < output_cols) {
+ expected += output_backward(b, output_j, output_i, od) *
+ kernel(c, r, id, od);
+ }
+ }
+ }
+ }
+ EigenApprox(input_backward(b, j, i, id), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_batched_cuboid_convolution_backward_input_valid) {
+ const int num_batches = 13;
+ const int input_depth = 2;
+ const int input_planes = 5;
+ const int input_rows = 3;
+ const int input_cols = 4;
+ const int patch_rows = 2;
+ const int patch_cols = 2;
+ const int patch_planes = 2;
+ const int output_rows = input_rows - patch_rows + 1;
+ const int output_cols = input_cols - patch_cols + 1;
+ const int output_planes = input_planes - patch_planes + 1;
+ const int output_depth = 5;
+
+ Tensor<float, 5> input_backward(input_depth, input_planes, input_rows,
+ input_cols, num_batches);
+ Tensor<float, 5> kernel(output_depth, input_depth, patch_planes, patch_rows,
+ patch_cols);
+ Tensor<float, 5> output_backward(output_depth, output_planes, output_rows,
+ output_cols, num_batches);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ input_backward.setRandom();
+
+ input_backward = CuboidConvolutionBackwardInput(
+ kernel, output_backward, input_planes, input_rows, input_cols);
+
+ EXPECT_EQ(input_backward.dimension(4), num_batches);
+ EXPECT_EQ(input_backward.dimension(3), input_cols);
+ EXPECT_EQ(input_backward.dimension(2), input_rows);
+ EXPECT_EQ(input_backward.dimension(1), input_planes);
+ EXPECT_EQ(input_backward.dimension(0), input_depth);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int id = 0; id < input_depth; ++id) {
+ for (int i = 0; i < input_planes; ++i) {
+ for (int j = 0; j < input_rows; ++j) {
+ for (int k = 0; k < input_cols; ++k) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int p = 0; p < patch_planes; ++p) {
+ for (int od = 0; od < output_depth; ++od) {
+ int output_i = i - p;
+ int output_j = j - r;
+ int output_k = k - c;
+ if (output_i >= 0 && output_i < output_planes &&
+ output_j >= 0 && output_j < output_rows &&
+ output_k >= 0 && output_k < output_cols) {
+ expected +=
+ output_backward(od, output_i, output_j, output_k, b) *
+ kernel(od, id, p, r, c);
+ }
+ }
+ }
+ }
+ }
+ EigenApprox(input_backward(id, i, j, k, b), expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_batched_cuboid_convolution_backward_input_valid_row_major) {
+ const int num_batches = 13;
+ const int input_depth = 2;
+ const int input_planes = 5;
+ const int input_rows = 3;
+ const int input_cols = 4;
+ const int patch_rows = 2;
+ const int patch_cols = 2;
+ const int patch_planes = 2;
+ const int output_rows = input_rows - patch_rows + 1;
+ const int output_cols = input_cols - patch_cols + 1;
+ const int output_planes = input_planes - patch_planes + 1;
+ const int output_depth = 5;
+
+ Tensor<float, 5, RowMajor> input_backward(num_batches, input_cols, input_rows,
+ input_planes, input_depth);
+ Tensor<float, 5, RowMajor> kernel(patch_cols, patch_rows, patch_planes,
+ input_depth, output_depth);
+ Tensor<float, 5, RowMajor> output_backward(
+ num_batches, output_cols, output_rows, output_planes, output_depth);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ input_backward.setRandom();
+
+ input_backward = CuboidConvolutionBackwardInput(
+ kernel, output_backward, input_planes, input_rows, input_cols);
+
+ EXPECT_EQ(input_backward.dimension(0), num_batches);
+ EXPECT_EQ(input_backward.dimension(1), input_cols);
+ EXPECT_EQ(input_backward.dimension(2), input_rows);
+ EXPECT_EQ(input_backward.dimension(3), input_planes);
+ EXPECT_EQ(input_backward.dimension(4), input_depth);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int id = 0; id < input_depth; ++id) {
+ for (int i = 0; i < input_planes; ++i) {
+ for (int j = 0; j < input_rows; ++j) {
+ for (int k = 0; k < input_cols; ++k) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int p = 0; p < patch_planes; ++p) {
+ for (int od = 0; od < output_depth; ++od) {
+ int output_i = i - p;
+ int output_j = j - r;
+ int output_k = k - c;
+ if (output_i >= 0 && output_i < output_planes &&
+ output_j >= 0 && output_j < output_rows &&
+ output_k >= 0 && output_k < output_cols) {
+ expected +=
+ output_backward(b, output_k, output_j, output_i, od) *
+ kernel(c, r, p, id, od);
+ }
+ }
+ }
+ }
+ }
+ EigenApprox(input_backward(b, k, j, i, id), expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_batched_strided_spatial_convolution_backward_input_valid) {
+ const int num_batches = 11;
+ const int input_depth = 2;
+ const int input_rows = 9;
+ const int input_cols = 13;
+ const int output_depth = 5;
+ const int patch_rows = 3;
+ const int patch_cols = 3;
+
+ const int stride = 3;
+
+ const int output_rows = (input_rows - patch_rows + 1 + stride - 1) / stride;
+ const int output_cols = (input_cols - patch_cols + 1 + stride - 1) / stride;
+
+ Tensor<float, 4> input_backward(input_depth, input_rows, input_cols,
+ num_batches);
+ Tensor<float, 4> kernel(output_depth, input_depth, patch_rows, patch_cols);
+ Tensor<float, 4> output_backward(output_depth, output_rows, output_cols,
+ num_batches);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ input_backward.setRandom();
+
+ input_backward = SpatialConvolutionBackwardInput(
+ kernel, output_backward, input_rows, input_cols, stride);
+
+ EXPECT_EQ(input_backward.dimension(0), input_depth);
+ EXPECT_EQ(input_backward.dimension(1), input_rows);
+ EXPECT_EQ(input_backward.dimension(2), input_cols);
+ EXPECT_EQ(input_backward.dimension(3), num_batches);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int id = 0; id < input_depth; ++id) {
+ for (int i = 0; i < input_rows; ++i) {
+ for (int j = 0; j < input_cols; ++j) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int od = 0; od < output_depth; ++od) {
+ int output_i = i - r;
+ int output_j = j - c;
+ if (output_i >= 0 && output_i / stride < output_rows &&
+ output_j >= 0 && output_j / stride < output_cols &&
+ output_i % stride == 0 && output_j % stride == 0) {
+ expected += output_backward(od, output_i / stride,
+ output_j / stride, b) *
+ kernel(od, id, r, c);
+ }
+ }
+ }
+ }
+ EigenApprox(input_backward(id, i, j, b), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_batched_strided_spatial_convolution_backward_input_valid_row_major) {
+ const int num_batches = 11;
+ const int input_depth = 3;
+ const int input_rows = 5;
+ const int input_cols = 9;
+ const int output_depth = 1;
+ const int patch_rows = 3;
+ const int patch_cols = 3;
+
+ const int stride = 2;
+
+ const int output_rows = (input_rows - patch_rows + 2) / stride;
+ const int output_cols = (input_cols - patch_cols + 2) / stride;
+
+ Tensor<float, 4, RowMajor> input_backward(num_batches, input_cols, input_rows,
+ input_depth);
+ Tensor<float, 4, RowMajor> kernel(patch_cols, patch_rows, input_depth,
+ output_depth);
+ Tensor<float, 4, RowMajor> output_backward(num_batches, output_cols,
+ output_rows, output_depth);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ input_backward.setRandom();
+
+ input_backward = SpatialConvolutionBackwardInput(
+ kernel, output_backward, input_rows, input_cols, stride);
+
+ EXPECT_EQ(input_backward.dimension(0), num_batches);
+ EXPECT_EQ(input_backward.dimension(1), input_cols);
+ EXPECT_EQ(input_backward.dimension(2), input_rows);
+ EXPECT_EQ(input_backward.dimension(3), input_depth);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int id = 0; id < input_depth; ++id) {
+ for (int i = 0; i < input_rows; ++i) {
+ for (int j = 0; j < input_cols; ++j) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int od = 0; od < output_depth; ++od) {
+ int output_i = i - r;
+ int output_j = j - c;
+ if (output_i >= 0 && output_i / stride < output_rows &&
+ output_j >= 0 && output_j / stride < output_cols &&
+ output_i % stride == 0 && output_j % stride == 0) {
+ expected += output_backward(b, output_j / stride,
+ output_i / stride, od) *
+ kernel(c, r, id, od);
+ }
+ }
+ }
+ }
+ EigenApprox(input_backward(b, j, i, id), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_simple_spatial_convolution_backward_kernel_valid) {
+ const int input_depth = 2;
+ const int input_rows = 3;
+ const int input_cols = 4;
+ const int output_depth = 5;
+ const int patch_rows = 2;
+ const int patch_cols = 2;
+ const int output_rows = input_rows - patch_rows + 1;
+ const int output_cols = input_cols - patch_cols + 1;
+
+ Tensor<float, 3> input(input_depth, input_rows, input_cols);
+ Tensor<float, 4> kernel(output_depth, input_depth, patch_rows, patch_cols);
+ Tensor<float, 3> output_backward(output_depth, output_rows, output_cols);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ input = input.constant(2.0f) + input.random();
+ kernel.setRandom();
+
+ kernel = SpatialConvolutionBackwardKernel(input, output_backward, patch_rows,
+ patch_cols, 1);
+
+ EXPECT_EQ(kernel.dimension(0), output_depth);
+ EXPECT_EQ(kernel.dimension(1), input_depth);
+ EXPECT_EQ(kernel.dimension(2), patch_rows);
+ EXPECT_EQ(kernel.dimension(3), patch_cols);
+
+ for (int od = 0; od < output_depth; ++od) {
+ for (int id = 0; id < input_depth; ++id) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int c = 0; c < patch_cols; ++c) {
+ float expected = 0.0f;
+ for (int i = 0; i < input_rows; ++i) {
+ for (int j = 0; j < input_cols; ++j) {
+ int output_i = i - r;
+ int output_j = j - c;
+ if (output_i >= 0 && output_i < output_rows && output_j >= 0 &&
+ output_j < output_cols) {
+ expected +=
+ input(id, i, j) * output_backward(od, output_i, output_j);
+ }
+ }
+ }
+ EigenApprox(kernel(od, id, r, c), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_simple_spatial_convolution_backward_kernel_valid_row_major) {
+ const int input_depth = 2;
+ const int input_rows = 3;
+ const int input_cols = 4;
+ const int output_depth = 5;
+ const int patch_rows = 2;
+ const int patch_cols = 2;
+ const int output_rows = input_rows - patch_rows + 1;
+ const int output_cols = input_cols - patch_cols + 1;
+
+ Tensor<float, 3, RowMajor> input(input_cols, input_rows, input_depth);
+ Tensor<float, 4, RowMajor> kernel(patch_cols, patch_rows, input_depth,
+ output_depth);
+ Tensor<float, 3, RowMajor> output_backward(output_cols, output_rows,
+ output_depth);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ input = input.constant(2.0f) + input.random();
+ kernel.setRandom();
+
+ kernel = SpatialConvolutionBackwardKernel(input, output_backward, patch_rows,
+ patch_cols, 1);
+
+ EXPECT_EQ(kernel.dimension(0), patch_cols);
+ EXPECT_EQ(kernel.dimension(1), patch_rows);
+ EXPECT_EQ(kernel.dimension(2), input_depth);
+ EXPECT_EQ(kernel.dimension(3), output_depth);
+
+ for (int od = 0; od < output_depth; ++od) {
+ for (int id = 0; id < input_depth; ++id) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int c = 0; c < patch_cols; ++c) {
+ float expected = 0.0f;
+ for (int i = 0; i < input_rows; ++i) {
+ for (int j = 0; j < input_cols; ++j) {
+ int output_i = i - r;
+ int output_j = j - c;
+ if (output_i >= 0 && output_i < output_rows && output_j >= 0 &&
+ output_j < output_cols) {
+ expected +=
+ input(j, i, id) * output_backward(output_j, output_i, od);
+ }
+ }
+ }
+ EigenApprox(kernel(c, r, id, od), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_batched_atrous_spatial_convolution_backward_input_valid) {
+ const int num_batches = 11;
+ const int patch_rows = 3;
+ const int patch_cols = 3;
+
+ const int input_depth = 2;
+ const int input_rows = 9;
+ const int input_cols = 13;
+
+ const int in_stride = 3;
+ const int patch_rows_eff = patch_rows + (patch_rows - 1) * (in_stride - 1);
+ const int patch_cols_eff = patch_cols + (patch_cols - 1) * (in_stride - 1);
+
+ const int output_depth = 5;
+ const int output_rows = input_rows - patch_rows_eff + 1;
+ const int output_cols = input_cols - patch_cols_eff + 1;
+
+ Tensor<float, 4> output_backward(output_depth, output_rows, output_cols,
+ num_batches);
+ output_backward.setRandom();
+ Tensor<float, 4> kernel(output_depth, input_depth, patch_rows, patch_cols);
+ kernel.setRandom();
+
+ const array<DenseIndex, 4> kernel_strides({1, 1, in_stride, in_stride});
+ const Tensor<float, 4> kernel_eff = kernel.inflate(kernel_strides);
+
+ const Tensor<float, 4> input_backward = SpatialConvolutionBackwardInput(
+ kernel, output_backward, input_rows, input_cols, 1, in_stride);
+ const Tensor<float, 4> expected_input_backward =
+ SpatialConvolutionBackwardInput(kernel_eff, output_backward, input_rows,
+ input_cols);
+
+ EXPECT_EQ(input_backward.dimension(0), input_depth);
+ EXPECT_EQ(input_backward.dimension(1), input_rows);
+ EXPECT_EQ(input_backward.dimension(2), input_cols);
+ EXPECT_EQ(input_backward.dimension(3), num_batches);
+
+ eigen_assert(dimensions_match(input_backward.dimensions(),
+ expected_input_backward.dimensions()));
+ for (size_t i = 0; i < input_backward.dimensions().TotalSize(); ++i) {
+ EigenApprox(input_backward.data()[i], expected_input_backward.data()[i]);
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_batched_atrous_spatial_convolution_backward_input_valid_row_major) {
+ const int num_batches = 11;
+ const int patch_rows = 3;
+ const int patch_cols = 3;
+
+ const int input_depth = 2;
+ const int input_rows = 9;
+ const int input_cols = 13;
+
+ const int in_stride = 3;
+ const int patch_rows_eff = patch_rows + (patch_rows - 1) * (in_stride - 1);
+ const int patch_cols_eff = patch_cols + (patch_cols - 1) * (in_stride - 1);
+
+ const int output_depth = 5;
+ const int output_rows = input_rows - patch_rows_eff + 1;
+ const int output_cols = input_cols - patch_cols_eff + 1;
+
+ Tensor<float, 4, RowMajor> output_backward(num_batches, output_cols,
+ output_rows, output_depth);
+ output_backward.setRandom();
+
+ Tensor<float, 4, RowMajor> kernel(patch_cols, patch_rows, input_depth,
+ output_depth);
+ kernel.setRandom();
+
+ const array<DenseIndex, 4> kernel_strides({in_stride, in_stride, 1, 1});
+ const Tensor<float, 4, RowMajor> kernel_eff = kernel.inflate(kernel_strides);
+
+ const Tensor<float, 4, RowMajor> input_backward =
+ SpatialConvolutionBackwardInput(kernel, output_backward, input_rows,
+ input_cols, 1, in_stride);
+ const Tensor<float, 4, RowMajor> expected_input_backward =
+ SpatialConvolutionBackwardInput(kernel_eff, output_backward, input_rows,
+ input_cols);
+
+ EXPECT_EQ(input_backward.dimension(0), num_batches);
+ EXPECT_EQ(input_backward.dimension(1), input_cols);
+ EXPECT_EQ(input_backward.dimension(2), input_rows);
+ EXPECT_EQ(input_backward.dimension(3), input_depth);
+
+ eigen_assert(dimensions_match(input_backward.dimensions(),
+ expected_input_backward.dimensions()));
+ for (size_t i = 0; i < input_backward.dimensions().TotalSize(); ++i) {
+ EigenApprox(input_backward.data()[i], expected_input_backward.data()[i]);
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_batched_atrous_spatial_convolution_backward_kernel_valid) {
+ const int num_batches = 11;
+ const int patch_rows = 3;
+ const int patch_cols = 3;
+
+ const int input_depth = 2;
+ const int input_rows = 9;
+ const int input_cols = 13;
+
+ const int in_stride = 3;
+ const int patch_rows_eff = patch_rows + (patch_rows - 1) * (in_stride - 1);
+ const int patch_cols_eff = patch_cols + (patch_cols - 1) * (in_stride - 1);
+
+ const int output_depth = 5;
+ const int output_rows = input_rows - patch_rows_eff + 1;
+ const int output_cols = input_cols - patch_cols_eff + 1;
+
+ Tensor<float, 4> output_backward(output_depth, output_rows, output_cols,
+ num_batches);
+ output_backward.setRandom();
+
+ Tensor<float, 4> input(input_depth, input_rows, input_cols, num_batches);
+ input.setRandom();
+
+ const array<DenseIndex, 4> kernel_strides({1, 1, in_stride, in_stride});
+
+ const Tensor<float, 4> kernel_backward = SpatialConvolutionBackwardKernel(
+ input, output_backward, patch_rows, patch_cols, 1, in_stride);
+ const Tensor<float, 4> expected_kernel_backward =
+ SpatialConvolutionBackwardKernel(input, output_backward, patch_rows_eff,
+ patch_cols_eff)
+ .stride(kernel_strides);
+
+ EXPECT_EQ(kernel_backward.dimension(0), output_depth);
+ EXPECT_EQ(kernel_backward.dimension(1), input_depth);
+ EXPECT_EQ(kernel_backward.dimension(2), patch_rows);
+ EXPECT_EQ(kernel_backward.dimension(3), patch_cols);
+
+ eigen_assert(dimensions_match(kernel_backward.dimensions(),
+ expected_kernel_backward.dimensions()));
+ for (size_t i = 0; i < kernel_backward.dimensions().TotalSize(); ++i) {
+ EigenApprox(kernel_backward.data()[i], expected_kernel_backward.data()[i]);
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_batched_atrous_spatial_convolution_backward_kernel_valid_row_major) {
+ const int num_batches = 11;
+ const int patch_rows = 3;
+ const int patch_cols = 3;
+
+ const int input_depth = 2;
+ const int input_rows = 9;
+ const int input_cols = 13;
+
+ const int in_stride = 3;
+ const int patch_rows_eff = patch_rows + (patch_rows - 1) * (in_stride - 1);
+ const int patch_cols_eff = patch_cols + (patch_cols - 1) * (in_stride - 1);
+
+ const int output_depth = 5;
+ const int output_rows = input_rows - patch_rows_eff + 1;
+ const int output_cols = input_cols - patch_cols_eff + 1;
+
+ Tensor<float, 4, RowMajor> output_backward(num_batches, output_cols,
+ output_rows, output_depth);
+ output_backward.setRandom();
+
+ Tensor<float, 4, RowMajor> input(num_batches, input_cols, input_rows,
+ input_depth);
+ input.setRandom();
+
+ const array<DenseIndex, 4> kernel_strides({in_stride, in_stride, 1, 1});
+
+ const Tensor<float, 4, RowMajor> kernel_backward =
+ SpatialConvolutionBackwardKernel(input, output_backward, patch_rows,
+ patch_cols, 1, in_stride);
+ const Tensor<float, 4, RowMajor> expected_kernel_backward =
+ SpatialConvolutionBackwardKernel(input, output_backward, patch_rows_eff,
+ patch_cols_eff)
+ .stride(kernel_strides);
+
+ EXPECT_EQ(kernel_backward.dimension(0), patch_cols);
+ EXPECT_EQ(kernel_backward.dimension(1), patch_rows);
+ EXPECT_EQ(kernel_backward.dimension(2), input_depth);
+ EXPECT_EQ(kernel_backward.dimension(3), output_depth);
+
+ eigen_assert(dimensions_match(kernel_backward.dimensions(),
+ expected_kernel_backward.dimensions()));
+ for (size_t i = 0; i < kernel_backward.dimensions().TotalSize(); ++i) {
+ EigenApprox(kernel_backward.data()[i], expected_kernel_backward.data()[i]);
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_simple_cuboid_convolution_backward_kernel_valid) {
+ const int input_depth = 2;
+ const int input_planes = 5;
+ const int input_rows = 3;
+ const int input_cols = 4;
+ const int output_depth = 5;
+ const int patch_rows = 2;
+ const int patch_cols = 2;
+ const int patch_planes = 3;
+ const int output_rows = input_rows - patch_rows + 1;
+ const int output_cols = input_cols - patch_cols + 1;
+ const int output_planes = input_planes - patch_planes + 1;
+
+ Tensor<float, 4> input(input_depth, input_planes, input_rows, input_cols);
+ Tensor<float, 5> kernel(output_depth, input_depth, patch_planes, patch_rows,
+ patch_cols);
+ Tensor<float, 4> output_backward(output_depth, output_planes, output_rows,
+ output_cols);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ input = input.constant(2.0f) + input.random();
+ kernel.setRandom();
+
+ kernel = CuboidConvolutionBackwardKernel(input, output_backward, patch_planes,
+ patch_rows, patch_cols, 1, 1, 1);
+
+ EXPECT_EQ(kernel.dimension(0), output_depth);
+ EXPECT_EQ(kernel.dimension(1), input_depth);
+ EXPECT_EQ(kernel.dimension(2), patch_planes);
+ EXPECT_EQ(kernel.dimension(3), patch_rows);
+ EXPECT_EQ(kernel.dimension(4), patch_cols);
+
+ for (int od = 0; od < output_depth; ++od) {
+ for (int id = 0; id < input_depth; ++id) {
+ for (int p = 0; p < patch_planes; ++p) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int c = 0; c < patch_cols; ++c) {
+ float expected = 0.0f;
+ for (int i = 0; i < input_planes; ++i) {
+ for (int j = 0; j < input_rows; ++j) {
+ for (int k = 0; k < input_cols; ++k) {
+ int output_j = j - r;
+ int output_k = k - c;
+ int output_i = i - p;
+ if (output_i >= 0 && output_i < output_planes &&
+ output_j >= 0 && output_j < output_rows &&
+ output_k >= 0 && output_k < output_cols) {
+ expected +=
+ input(id, i, j, k) *
+ output_backward(od, output_i, output_j, output_k);
+ }
+ }
+ }
+ }
+ EigenApprox(kernel(od, id, p, r, c), expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_simple_cuboid_convolution_backward_kernel_valid_row_major) {
+ const int input_depth = 2;
+ const int input_planes = 5;
+ const int input_rows = 3;
+ const int input_cols = 4;
+ const int output_depth = 5;
+ const int patch_rows = 2;
+ const int patch_cols = 2;
+ const int patch_planes = 3;
+ const int output_rows = input_rows - patch_rows + 1;
+ const int output_cols = input_cols - patch_cols + 1;
+ const int output_planes = input_planes - patch_planes + 1;
+
+ Tensor<float, 4, RowMajor> input(input_cols, input_rows, input_planes,
+ input_depth);
+ Tensor<float, 5, RowMajor> kernel(patch_cols, patch_rows, patch_planes,
+ input_depth, output_depth);
+ Tensor<float, 4, RowMajor> output_backward(output_cols, output_rows,
+ output_planes, output_depth);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ input = input.constant(2.0f) + input.random();
+ kernel.setRandom();
+
+ kernel = CuboidConvolutionBackwardKernel(input, output_backward, patch_planes,
+ patch_rows, patch_cols, 1, 1, 1);
+
+ EXPECT_EQ(kernel.dimension(4), output_depth);
+ EXPECT_EQ(kernel.dimension(3), input_depth);
+ EXPECT_EQ(kernel.dimension(2), patch_planes);
+ EXPECT_EQ(kernel.dimension(1), patch_rows);
+ EXPECT_EQ(kernel.dimension(0), patch_cols);
+
+ for (int od = 0; od < output_depth; ++od) {
+ for (int id = 0; id < input_depth; ++id) {
+ for (int p = 0; p < patch_planes; ++p) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int c = 0; c < patch_cols; ++c) {
+ float expected = 0.0f;
+ for (int i = 0; i < input_planes; ++i) {
+ for (int j = 0; j < input_rows; ++j) {
+ for (int k = 0; k < input_cols; ++k) {
+ int output_j = j - r;
+ int output_k = k - c;
+ int output_i = i - p;
+ if (output_i >= 0 && output_i < output_planes &&
+ output_j >= 0 && output_j < output_rows &&
+ output_k >= 0 && output_k < output_cols) {
+ expected +=
+ input(k, j, i, id) *
+ output_backward(output_k, output_j, output_i, od);
+ }
+ }
+ }
+ }
+ EigenApprox(kernel(c, r, p, id, od), expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_batched_spatial_convolution_backward_kernel_valid) {
+ const int num_batches = 13;
+ const int input_depth = 2;
+ const int input_rows = 7;
+ const int input_cols = 9;
+ const int output_depth = 3;
+ const int patch_rows = 5;
+ const int patch_cols = 5;
+ const int output_rows = input_rows - patch_rows + 1;
+ const int output_cols = input_cols - patch_cols + 1;
+
+ Tensor<float, 4> input(input_depth, input_rows, input_cols, num_batches);
+ Tensor<float, 4> kernel_backward(output_depth, input_depth, patch_rows,
+ patch_cols);
+ Tensor<float, 4> output_backward(output_depth, output_rows, output_cols,
+ num_batches);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ input = input.constant(2.0f) + input.random();
+ kernel_backward.setRandom();
+
+ kernel_backward = SpatialConvolutionBackwardKernel(input, output_backward,
+ patch_rows, patch_cols, 1);
+
+ EXPECT_EQ(kernel_backward.dimension(0), output_depth);
+ EXPECT_EQ(kernel_backward.dimension(1), input_depth);
+ EXPECT_EQ(kernel_backward.dimension(2), patch_rows);
+ EXPECT_EQ(kernel_backward.dimension(3), patch_cols);
+
+ for (int od = 0; od < output_depth; ++od) {
+ for (int id = 0; id < input_depth; ++id) {
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ float expected = 0.0f;
+ for (int b = 0; b < num_batches; ++b) {
+ for (int i = 0; i < input_rows; ++i) {
+ for (int j = 0; j < input_cols; ++j) {
+ int output_i = i - r;
+ int output_j = j - c;
+ if (output_i >= 0 && output_i < output_rows && output_j >= 0 &&
+ output_j < output_cols) {
+ expected += input(id, i, j, b) *
+ output_backward(od, output_i, output_j, b);
+ }
+ }
+ }
+ }
+ EigenApprox(kernel_backward(od, id, r, c), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_batched_spatial_convolution_backward_kernel_valid_row_major) {
+ const int num_batches = 13;
+ const int input_depth = 2;
+ const int input_rows = 7;
+ const int input_cols = 9;
+ const int output_depth = 3;
+ const int patch_rows = 4;
+ const int patch_cols = 4;
+ const int output_rows = input_rows - patch_rows + 1;
+ const int output_cols = input_cols - patch_cols + 1;
+
+ Tensor<float, 4, RowMajor> input(num_batches, input_cols, input_rows,
+ input_depth);
+ Tensor<float, 4, RowMajor> kernel_backward(patch_cols, patch_rows,
+ input_depth, output_depth);
+ Tensor<float, 4, RowMajor> output_backward(num_batches, output_cols,
+ output_rows, output_depth);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ input = input.constant(2.0f) + input.random();
+ kernel_backward.setRandom();
+
+ kernel_backward = SpatialConvolutionBackwardKernel(input, output_backward,
+ patch_rows, patch_cols, 1);
+
+ EXPECT_EQ(kernel_backward.dimension(0), patch_cols);
+ EXPECT_EQ(kernel_backward.dimension(1), patch_rows);
+ EXPECT_EQ(kernel_backward.dimension(2), input_depth);
+ EXPECT_EQ(kernel_backward.dimension(3), output_depth);
+
+ for (int od = 0; od < output_depth; ++od) {
+ for (int id = 0; id < input_depth; ++id) {
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ float expected = 0.0f;
+ for (int b = 0; b < num_batches; ++b) {
+ for (int i = 0; i < input_rows; ++i) {
+ for (int j = 0; j < input_cols; ++j) {
+ int output_i = i - r;
+ int output_j = j - c;
+ if (output_i >= 0 && output_i < output_rows && output_j >= 0 &&
+ output_j < output_cols) {
+ expected += input(b, j, i, id) *
+ output_backward(b, output_j, output_i, od);
+ }
+ }
+ }
+ }
+ EigenApprox(kernel_backward(c, r, id, od), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_batched_cuboid_convolution_backward_kernel_valid) {
+ const int num_batches = 13;
+ const int input_depth = 2;
+ const int input_planes = 5;
+ const int input_rows = 7;
+ const int input_cols = 9;
+ const int output_depth = 3;
+ const int patch_rows = 5;
+ const int patch_cols = 5;
+ const int patch_planes = 3;
+ const int output_rows = input_rows - patch_rows + 1;
+ const int output_cols = input_cols - patch_cols + 1;
+ const int output_planes = input_planes - patch_planes + 1;
+
+ Tensor<float, 5> input(input_depth, input_planes, input_rows, input_cols,
+ num_batches);
+ Tensor<float, 5> kernel_backward(output_depth, input_depth, patch_planes,
+ patch_rows, patch_cols);
+ Tensor<float, 5> output_backward(output_depth, output_planes, output_rows,
+ output_cols, num_batches);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ input = input.constant(2.0f) + input.random();
+ kernel_backward.setRandom();
+
+ kernel_backward = CuboidConvolutionBackwardKernel(
+ input, output_backward, patch_planes, patch_rows, patch_cols, 1, 1, 1);
+
+ EXPECT_EQ(kernel_backward.dimension(0), output_depth);
+ EXPECT_EQ(kernel_backward.dimension(1), input_depth);
+ EXPECT_EQ(kernel_backward.dimension(2), patch_planes);
+ EXPECT_EQ(kernel_backward.dimension(3), patch_rows);
+ EXPECT_EQ(kernel_backward.dimension(4), patch_cols);
+
+ for (int od = 0; od < output_depth; ++od) {
+ for (int id = 0; id < input_depth; ++id) {
+ for (int p = 0; p < patch_planes; ++p) {
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ float expected = 0.0f;
+ for (int b = 0; b < num_batches; ++b) {
+ for (int i = 0; i < input_planes; ++i) {
+ for (int j = 0; j < input_rows; ++j) {
+ for (int k = 0; k < input_cols; ++k) {
+ int output_j = j - r;
+ int output_k = k - c;
+ int output_i = i - p;
+ if (output_i >= 0 && output_i < output_planes &&
+ output_j >= 0 && output_j < output_rows &&
+ output_k >= 0 && output_k < output_cols) {
+ expected +=
+ input(id, i, j, k, b) *
+ output_backward(od, output_i, output_j, output_k, b);
+ }
+ }
+ }
+ }
+ }
+ EigenApprox(kernel_backward(od, id, p, r, c), expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_batched_cuboid_convolution_backward_kernel_valid_row_major) {
+ const int num_batches = 13;
+ const int input_depth = 2;
+ const int input_planes = 5;
+ const int input_rows = 7;
+ const int input_cols = 9;
+ const int output_depth = 3;
+ const int patch_rows = 5;
+ const int patch_cols = 5;
+ const int patch_planes = 3;
+ const int output_rows = input_rows - patch_rows + 1;
+ const int output_cols = input_cols - patch_cols + 1;
+ const int output_planes = input_planes - patch_planes + 1;
+
+ Tensor<float, 5, RowMajor> input(num_batches, input_cols, input_rows,
+ input_planes, input_depth);
+ Tensor<float, 5, RowMajor> kernel_backward(
+ patch_cols, patch_rows, patch_planes, input_depth, output_depth);
+ Tensor<float, 5, RowMajor> output_backward(
+ num_batches, output_cols, output_rows, output_planes, output_depth);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ input = input.constant(2.0f) + input.random();
+ kernel_backward.setRandom();
+
+ kernel_backward = CuboidConvolutionBackwardKernel(
+ input, output_backward, patch_planes, patch_rows, patch_cols, 1, 1, 1);
+
+ EXPECT_EQ(kernel_backward.dimension(4), output_depth);
+ EXPECT_EQ(kernel_backward.dimension(3), input_depth);
+ EXPECT_EQ(kernel_backward.dimension(2), patch_planes);
+ EXPECT_EQ(kernel_backward.dimension(1), patch_rows);
+ EXPECT_EQ(kernel_backward.dimension(0), patch_cols);
+
+ for (int od = 0; od < output_depth; ++od) {
+ for (int id = 0; id < input_depth; ++id) {
+ for (int p = 0; p < patch_planes; ++p) {
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ float expected = 0.0f;
+ for (int b = 0; b < num_batches; ++b) {
+ for (int i = 0; i < input_planes; ++i) {
+ for (int j = 0; j < input_rows; ++j) {
+ for (int k = 0; k < input_cols; ++k) {
+ int output_j = j - r;
+ int output_k = k - c;
+ int output_i = i - p;
+ if (output_i >= 0 && output_i < output_planes &&
+ output_j >= 0 && output_j < output_rows &&
+ output_k >= 0 && output_k < output_cols) {
+ expected +=
+ input(b, k, j, i, id) *
+ output_backward(b, output_k, output_j, output_i, od);
+ }
+ }
+ }
+ }
+ }
+ EigenApprox(kernel_backward(c, r, p, id, od), expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_batched_strided_spatial_convolution_backward_kernel_valid) {
+ const int num_batches = 13;
+ const int input_depth = 2;
+ const int input_rows = 7;
+ const int input_cols = 9;
+ const int output_depth = 3;
+ const int patch_rows = 5;
+ const int patch_cols = 5;
+
+ const int stride = 2;
+
+ const int output_rows = (input_rows - patch_rows + 1 + stride - 1) / stride;
+ const int output_cols = (input_cols - patch_cols + 1 + stride - 1) / stride;
+
+ Tensor<float, 4> input(input_depth, input_rows, input_cols, num_batches);
+ Tensor<float, 4> kernel_backward(output_depth, input_depth, patch_rows,
+ patch_cols);
+ Tensor<float, 4> output_backward(output_depth, output_rows, output_cols,
+ num_batches);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ input = input.constant(2.0f) + input.random();
+ kernel_backward.setRandom();
+
+ kernel_backward = SpatialConvolutionBackwardKernel(
+ input, output_backward, patch_rows, patch_cols, stride);
+
+ EXPECT_EQ(kernel_backward.dimension(0), output_depth);
+ EXPECT_EQ(kernel_backward.dimension(1), input_depth);
+ EXPECT_EQ(kernel_backward.dimension(2), patch_rows);
+ EXPECT_EQ(kernel_backward.dimension(3), patch_cols);
+
+ for (int od = 0; od < output_depth; ++od) {
+ for (int id = 0; id < input_depth; ++id) {
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ float expected = 0.0f;
+ for (int b = 0; b < num_batches; ++b) {
+ for (int i = 0; i < input_rows; ++i) {
+ for (int j = 0; j < input_cols; ++j) {
+ int output_i = i - r;
+ int output_j = j - c;
+ if (output_i >= 0 && output_i / stride < output_rows &&
+ output_j >= 0 && output_j / stride < output_cols &&
+ output_i % stride == 0 && output_j % stride == 0) {
+ expected += input(id, i, j, b) *
+ output_backward(od, output_i / stride,
+ output_j / stride, b);
+ }
+ }
+ }
+ }
+ EigenApprox(kernel_backward(od, id, r, c), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_batched_strided_spatial_convolution_backward_kernel_valid_row_major) {
+ const int num_batches = 13;
+ const int input_depth = 2;
+ const int input_rows = 7;
+ const int input_cols = 9;
+ const int output_depth = 3;
+ const int patch_rows = 4;
+ const int patch_cols = 4;
+
+ const int stride = 2;
+
+ const int output_rows = (input_rows - patch_rows + 1 + stride - 1) / stride;
+ const int output_cols = (input_cols - patch_cols + 1 + stride - 1) / stride;
+
+ Tensor<float, 4, RowMajor> input(num_batches, input_cols, input_rows,
+ input_depth);
+ Tensor<float, 4, RowMajor> kernel_backward(patch_cols, patch_rows,
+ input_depth, output_depth);
+ Tensor<float, 4, RowMajor> output_backward(num_batches, output_cols,
+ output_rows, output_depth);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ input = input.constant(2.0f) + input.random();
+ kernel_backward.setRandom();
+
+ kernel_backward = SpatialConvolutionBackwardKernel(
+ input, output_backward, patch_rows, patch_cols, stride);
+
+ EXPECT_EQ(kernel_backward.dimension(0), patch_cols);
+ EXPECT_EQ(kernel_backward.dimension(1), patch_rows);
+ EXPECT_EQ(kernel_backward.dimension(2), input_depth);
+ EXPECT_EQ(kernel_backward.dimension(3), output_depth);
+
+ for (int od = 0; od < output_depth; ++od) {
+ for (int id = 0; id < input_depth; ++id) {
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ float expected = 0.0f;
+ for (int b = 0; b < num_batches; ++b) {
+ for (int i = 0; i < input_rows; ++i) {
+ for (int j = 0; j < input_cols; ++j) {
+ int output_i = i - r;
+ int output_j = j - c;
+ if (output_i >= 0 && output_i / stride < output_rows &&
+ output_j >= 0 && output_j / stride < output_cols &&
+ output_i % stride == 0 && output_j % stride == 0) {
+ expected += input(b, j, i, id) *
+ output_backward(b, output_j / stride,
+ output_i / stride, od);
+ }
+ }
+ }
+ }
+ EigenApprox(kernel_backward(c, r, id, od), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_batched_strided_cuboid_convolution_backward_kernel_valid) {
+ const int num_batches = 13;
+ const int input_depth = 2;
+ const int input_planes = 8;
+ const int input_rows = 7;
+ const int input_cols = 9;
+ const int output_depth = 3;
+ const int patch_planes = 3;
+ const int patch_rows = 3;
+ const int patch_cols = 2;
+
+ const int stride_planes = 2;
+ const int stride_cols = 3;
+ const int stride_rows = 1;
+
+ const int output_rows = ceil_div(input_rows - patch_rows + 1, stride_rows);
+ const int output_cols = ceil_div(input_cols - patch_cols + 1, stride_cols);
+ const int output_planes =
+ ceil_div(input_planes - patch_planes + 1, stride_planes);
+
+ Tensor<float, 5> input(input_depth, input_planes, input_rows, input_cols,
+ num_batches);
+ Tensor<float, 5> kernel_backward(output_depth, input_depth, patch_planes,
+ patch_rows, patch_cols);
+ Tensor<float, 5> output_backward(output_depth, output_planes, output_rows,
+ output_cols, num_batches);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ input = input.constant(2.0f) + input.random();
+ kernel_backward.setRandom();
+
+ kernel_backward = CuboidConvolutionBackwardKernel(
+ input, output_backward, patch_planes, patch_rows, patch_cols,
+ stride_planes, stride_rows, stride_cols);
+
+ EXPECT_EQ(kernel_backward.dimension(0), output_depth);
+ EXPECT_EQ(kernel_backward.dimension(1), input_depth);
+ EXPECT_EQ(kernel_backward.dimension(2), patch_planes);
+ EXPECT_EQ(kernel_backward.dimension(3), patch_rows);
+ EXPECT_EQ(kernel_backward.dimension(4), patch_cols);
+
+ for (int od = 0; od < output_depth; ++od) {
+ for (int id = 0; id < input_depth; ++id) {
+ for (int p = 0; p < patch_planes; ++p) {
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ float expected = 0.0f;
+ for (int b = 0; b < num_batches; ++b) {
+ for (int i = 0; i < input_planes; ++i) {
+ for (int j = 0; j < input_rows; ++j) {
+ for (int k = 0; k < input_cols; ++k) {
+ int output_j = j - r;
+ int output_k = k - c;
+ int output_i = i - p;
+ if (output_i >= 0 &&
+ output_i / stride_planes < output_planes &&
+ output_j >= 0 && output_j / stride_rows < output_rows &&
+ output_k >= 0 && output_k / stride_cols < output_cols &&
+ output_i % stride_planes == 0 &&
+ output_j % stride_rows == 0 &&
+ output_k % stride_cols == 0) {
+ expected += input(id, i, j, k, b) *
+ output_backward(od, output_i / stride_planes,
+ output_j / stride_rows,
+ output_k / stride_cols, b);
+ }
+ }
+ }
+ }
+ }
+ EigenApprox(kernel_backward(od, id, p, r, c), expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_batched_strided_cuboid_convolution_backward_kernel_valid_row_major) {
+ const int num_batches = 13;
+ const int input_depth = 2;
+ const int input_planes = 8;
+ const int input_rows = 7;
+ const int input_cols = 9;
+ const int output_depth = 3;
+ const int patch_planes = 3;
+ const int patch_rows = 3;
+ const int patch_cols = 2;
+
+ const int stride_planes = 2;
+ const int stride_cols = 3;
+ const int stride_rows = 1;
+
+ const int output_rows = ceil_div(input_rows - patch_rows + 1, stride_rows);
+ const int output_cols = ceil_div(input_cols - patch_cols + 1, stride_cols);
+ const int output_planes =
+ ceil_div(input_planes - patch_planes + 1, stride_planes);
+
+ Tensor<float, 5, RowMajor> input(num_batches, input_cols, input_rows,
+ input_planes, input_depth);
+ Tensor<float, 5, RowMajor> kernel_backward(
+ patch_cols, patch_rows, patch_planes, input_depth, output_depth);
+ Tensor<float, 5, RowMajor> output_backward(
+ num_batches, output_cols, output_rows, output_planes, output_depth);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ input = input.constant(2.0f) + input.random();
+ kernel_backward.setRandom();
+
+ kernel_backward = CuboidConvolutionBackwardKernel(
+ input, output_backward, patch_planes, patch_rows, patch_cols,
+ stride_planes, stride_rows, stride_cols);
+
+ EXPECT_EQ(kernel_backward.dimension(4), output_depth);
+ EXPECT_EQ(kernel_backward.dimension(3), input_depth);
+ EXPECT_EQ(kernel_backward.dimension(2), patch_planes);
+ EXPECT_EQ(kernel_backward.dimension(1), patch_rows);
+ EXPECT_EQ(kernel_backward.dimension(0), patch_cols);
+
+ for (int od = 0; od < output_depth; ++od) {
+ for (int id = 0; id < input_depth; ++id) {
+ for (int p = 0; p < patch_planes; ++p) {
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ float expected = 0.0f;
+ for (int b = 0; b < num_batches; ++b) {
+ for (int i = 0; i < input_planes; ++i) {
+ for (int j = 0; j < input_rows; ++j) {
+ for (int k = 0; k < input_cols; ++k) {
+ int output_j = j - r;
+ int output_k = k - c;
+ int output_i = i - p;
+ if (output_i >= 0 &&
+ output_i / stride_planes < output_planes &&
+ output_j >= 0 && output_j / stride_rows < output_rows &&
+ output_k >= 0 && output_k / stride_cols < output_cols &&
+ output_i % stride_planes == 0 &&
+ output_j % stride_rows == 0 &&
+ output_k % stride_cols == 0) {
+ expected += input(b, k, j, i, id) *
+ output_backward(b, output_k / stride_cols,
+ output_j / stride_rows,
+ output_i / stride_planes, od);
+ }
+ }
+ }
+ }
+ }
+ EigenApprox(kernel_backward(c, r, p, id, od), expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_batched_strided_cuboid_convolution_backward_input_valid) {
+ const int num_batches = 13;
+ const int input_depth = 2;
+ const int input_planes = 14;
+ const int input_rows = 13;
+ const int input_cols = 15;
+ const int patch_rows = 3;
+ const int patch_cols = 2;
+ const int patch_planes = 4;
+ const int stride_rows = 3;
+ const int stride_cols = 2;
+ const int stride_planes = 3;
+ const int output_rows = ceil_div(input_rows - patch_rows + 1, stride_rows);
+ const int output_cols = ceil_div(input_cols - patch_cols + 1, stride_cols);
+ const int output_planes =
+ ceil_div(input_planes - patch_planes + 1, stride_planes);
+ const int output_depth = 5;
+
+ Tensor<float, 5> input_backward(input_depth, input_planes, input_rows,
+ input_cols, num_batches);
+ Tensor<float, 5> kernel(output_depth, input_depth, patch_planes, patch_rows,
+ patch_cols);
+ Tensor<float, 5> output_backward(output_depth, output_planes, output_rows,
+ output_cols, num_batches);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ input_backward.setRandom();
+
+ input_backward = CuboidConvolutionBackwardInput(
+ kernel, output_backward, input_planes, input_rows, input_cols,
+ stride_planes, stride_rows, stride_cols);
+
+ EXPECT_EQ(input_backward.dimension(4), num_batches);
+ EXPECT_EQ(input_backward.dimension(3), input_cols);
+ EXPECT_EQ(input_backward.dimension(2), input_rows);
+ EXPECT_EQ(input_backward.dimension(1), input_planes);
+ EXPECT_EQ(input_backward.dimension(0), input_depth);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int id = 0; id < input_depth; ++id) {
+ for (int i = 0; i < input_planes; ++i) {
+ for (int j = 0; j < input_rows; ++j) {
+ for (int k = 0; k < input_cols; ++k) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int p = 0; p < patch_planes; ++p) {
+ for (int od = 0; od < output_depth; ++od) {
+ int output_j = j - r;
+ int output_k = k - c;
+ int output_i = i - p;
+ if (output_i >= 0 &&
+ output_i / stride_planes < output_planes &&
+ output_j >= 0 && output_j / stride_rows < output_rows &&
+ output_k >= 0 && output_k / stride_cols < output_cols &&
+ output_i % stride_planes == 0 &&
+ output_j % stride_rows == 0 &&
+ output_k % stride_cols == 0) {
+ expected += output_backward(od, output_i / stride_planes,
+ output_j / stride_rows,
+ output_k / stride_cols, b) *
+ kernel(od, id, p, r, c);
+ }
+ }
+ }
+ }
+ }
+ EigenApprox(input_backward(id, i, j, k, b), expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenBackwardSpatialConvolutionsTest,
+ test_batched_strided_cuboid_convolution_backward_input_valid_row_major) {
+ const int num_batches = 13;
+ const int input_depth = 2;
+ const int input_planes = 14;
+ const int input_rows = 13;
+ const int input_cols = 15;
+ const int patch_rows = 3;
+ const int patch_cols = 2;
+ const int patch_planes = 4;
+ const int stride_rows = 3;
+ const int stride_cols = 2;
+ const int stride_planes = 3;
+ const int output_rows = ceil_div(input_rows - patch_rows + 1, stride_rows);
+ const int output_cols = ceil_div(input_cols - patch_cols + 1, stride_cols);
+ const int output_planes =
+ ceil_div(input_planes - patch_planes + 1, stride_planes);
+ const int output_depth = 5;
+
+ Tensor<float, 5, RowMajor> input_backward(num_batches, input_cols, input_rows,
+ input_planes, input_depth);
+ Tensor<float, 5, RowMajor> kernel(patch_cols, patch_rows, patch_planes,
+ input_depth, output_depth);
+ Tensor<float, 5, RowMajor> output_backward(
+ num_batches, output_cols, output_rows, output_planes, output_depth);
+
+ output_backward = output_backward.constant(11.0f) + output_backward.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ input_backward.setRandom();
+
+ input_backward = CuboidConvolutionBackwardInput(
+ kernel, output_backward, input_planes, input_rows, input_cols,
+ stride_planes, stride_rows, stride_cols);
+
+ EXPECT_EQ(input_backward.dimension(0), num_batches);
+ EXPECT_EQ(input_backward.dimension(1), input_cols);
+ EXPECT_EQ(input_backward.dimension(2), input_rows);
+ EXPECT_EQ(input_backward.dimension(3), input_planes);
+ EXPECT_EQ(input_backward.dimension(4), input_depth);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int id = 0; id < input_depth; ++id) {
+ for (int i = 0; i < input_planes; ++i) {
+ for (int j = 0; j < input_rows; ++j) {
+ for (int k = 0; k < input_cols; ++k) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int p = 0; p < patch_planes; ++p) {
+ for (int od = 0; od < output_depth; ++od) {
+ int output_j = j - r;
+ int output_k = k - c;
+ int output_i = i - p;
+ if (output_i >= 0 &&
+ output_i / stride_planes < output_planes &&
+ output_j >= 0 && output_j / stride_rows < output_rows &&
+ output_k >= 0 && output_k / stride_cols < output_cols &&
+ output_i % stride_planes == 0 &&
+ output_j % stride_rows == 0 &&
+ output_k % stride_cols == 0) {
+ expected +=
+ output_backward(b, output_k / stride_cols,
+ output_j / stride_rows,
+ output_i / stride_planes, od) *
+ kernel(c, r, p, id, od);
+ }
+ }
+ }
+ }
+ }
+ EigenApprox(input_backward(b, k, j, i, id), expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+} // namespace Eigen
diff --git a/tensorflow/core/kernels/eigen_cuboid_convolution.h b/tensorflow/core/kernels/eigen_cuboid_convolution.h
new file mode 100644
index 0000000000..ed4c3fca1a
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_cuboid_convolution.h
@@ -0,0 +1,195 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/eigen_patch_3d.h"
+
+namespace Eigen {
+
+/** CuboidConvolution
+ * \ingroup CXX11_NeuralNetworks_Module
+ *
+ * \brief Applies a 3D convolution over a multichannel input voxel block.
+ *
+ * The input parameter is expected to be a tensor with a rank of 4 or more (channels, depth, height, width, and optionally others).
+ * The kernel parameter is expected to be a 5D tensor (filters, channels, kernel_depth, kernel_height, kernel_width).
+ * The result can be assigned to a tensor of rank equal to the rank of the input. The dimensions of the result will be filters, depth, height, width (and others if applicable).
+ *
+ * The input and kernel have to be in the same layout, and both row-major and
+ * col-major are supported. The shapes given above are for col-major layout.
+ * For row-major, all dimensions should be reversed.
+ *
+ * It is possible to swap the order of the depth, width, and height dimensions provided that the same order is used in the input, the kernel, and the output.
+ */
+template <typename Input, typename Kernel>
+EIGEN_ALWAYS_INLINE
+static const typename internal::conditional <
+ internal::traits<Input>::Layout == ColMajor,
+ TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const TensorContractionOp<
+ const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const Kernel>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
+ const Input> > > >,
+ TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const TensorContractionOp<
+ const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
+ const Input> > ,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const Kernel> > > >::type
+CuboidConvolution(const Input& input, const Kernel& kernel,
+ const DenseIndex stridePlanes = 1,
+ const DenseIndex strideRows = 1,
+ const DenseIndex strideCols = 1,
+ const PaddingType padding_type = PADDING_SAME) {
+ typedef typename internal::traits<Input>::Index TensorIndex;
+ TensorRef<Tensor<typename internal::traits<Input>::Scalar, internal::traits<Input>::NumDimensions, internal::traits<Input>::Layout, TensorIndex> > in(input);
+ TensorRef<Tensor<typename internal::traits<Kernel>::Scalar, internal::traits<Kernel>::NumDimensions, internal::traits<Kernel>::Layout, TensorIndex> > kern(kernel);
+
+ EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == internal::traits<Kernel>::Layout, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
+ static const int NumDims = internal::traits<Input>::NumDimensions;
+
+ // Number of filters to apply. This is the same as the output depth of the result.
+ const TensorIndex kernelFilters = isColMajor ? kern.dimensions()[0] : kern.dimensions()[4];
+ const TensorIndex kernelChannels = isColMajor ? kern.dimensions()[1] : kern.dimensions()[3];
+
+ // Spatial size of the kernel.
+ const TensorIndex kernelDepth = isColMajor ? kern.dimensions()[2] : kern.dimensions()[2];
+ const TensorIndex kernelRows = isColMajor ? kern.dimensions()[3] : kern.dimensions()[1];
+ const TensorIndex kernelCols = isColMajor ? kern.dimensions()[4] : kern.dimensions()[0];
+
+ if (isColMajor) {
+ eigen_assert(kernelChannels == in.dimension(0));
+ } else {
+ eigen_assert(kernelChannels == in.dimension(NumDims - 1));
+ }
+
+ const TensorIndex inputPlanes = isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
+ const TensorIndex inputRows = isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
+ const TensorIndex inputCols = isColMajor ? in.dimension(3) : in.dimension(NumDims - 4);
+
+ const float stride_planes_f = static_cast<float>(stridePlanes);
+ const float stride_rows_f = static_cast<float>(strideRows);
+ const float stride_cols_f = static_cast<float>(strideCols);
+ TensorIndex out_depth;
+ TensorIndex out_height;
+ TensorIndex out_width;
+ switch (padding_type) {
+ case PADDING_VALID:
+ out_depth = ceil((inputPlanes - kernelDepth + 1.f) / stride_planes_f);
+ out_height = ceil((inputRows - kernelRows + 1.f) / stride_rows_f);
+ out_width = ceil((inputCols - kernelCols + 1.f) / stride_cols_f);
+ break;
+ case PADDING_SAME:
+ out_depth = ceil(inputPlanes / stride_planes_f);
+ out_height = ceil(inputRows / stride_rows_f);
+ out_width = ceil(inputCols / stride_cols_f);
+ break;
+ default:
+ eigen_assert(false && "unexpected padding");
+ }
+
+ DSizes<TensorIndex, 2> kernel_dims;
+ if (isColMajor) {
+ kernel_dims[0] = kernelFilters;
+ kernel_dims[1] = kernelChannels * kernelDepth * kernelRows * kernelCols;
+ } else {
+ kernel_dims[0] = kernelChannels * kernelDepth * kernelRows * kernelCols;
+ kernel_dims[1] = kernelFilters;
+ }
+
+ // Molds the output of the patch extraction result into a 2D tensor:
+ // - the first dimension (dims[0]): the patch values to be multiplied with the kernels
+ // - the second dimension (dims[1]): everything else
+ DSizes<TensorIndex, 2> pre_contract_dims;
+ if (isColMajor) {
+ pre_contract_dims[0] = kernelChannels * kernelDepth * kernelRows * kernelCols;
+ pre_contract_dims[1] = out_depth * out_height * out_width;
+ for (int i = 4; i < NumDims; ++i) {
+ pre_contract_dims[1] *= in.dimension(i);
+ }
+ } else {
+ pre_contract_dims[1] = kernelChannels * kernelDepth * kernelRows * kernelCols;
+ pre_contract_dims[0] = out_depth * out_height * out_width;
+ for (int i = 0; i < NumDims - 4; ++i) {
+ pre_contract_dims[0] *= in.dimension(i);
+ }
+ }
+
+ array<IndexPair<TensorIndex>, 1> contract_dims;
+ contract_dims[0] = IndexPair<TensorIndex>(1, 0);
+
+ // Molds the output of the contraction into the shape expected by the user
+ // (assuming ColMajor):
+ // - 1st dim: kernel filters
+ // - 2nd dim: output depth
+ // - 3nd dim: output height
+ // - 4rd dim: output width
+ // - 5th dim and beyond: everything else including batch size
+ DSizes<TensorIndex, NumDims> post_contract_dims;
+ if (isColMajor) {
+ post_contract_dims[0] = kernelFilters;
+ post_contract_dims[1] = out_depth;
+ post_contract_dims[2] = out_height;
+ post_contract_dims[3] = out_width;
+ for (int i = 4; i < NumDims; ++i) {
+ post_contract_dims[i] = in.dimension(i);
+ }
+ } else {
+ post_contract_dims[NumDims - 1] = kernelFilters;
+ post_contract_dims[NumDims - 2] = out_depth;
+ post_contract_dims[NumDims - 3] = out_height;
+ post_contract_dims[NumDims - 4] = out_width;
+ for (int i = 0; i < NumDims - 4; ++i) {
+ post_contract_dims[i] = in.dimension(i);
+ }
+ }
+
+ return choose(
+ Cond<internal::traits<Input>::Layout == ColMajor>(),
+ kernel.reshape(kernel_dims)
+ .contract(input.extract_volume_patches(
+ kernelDepth, kernelRows, kernelCols, stridePlanes,
+ strideRows, strideCols, padding_type)
+ .reshape(pre_contract_dims),
+ contract_dims)
+ .reshape(post_contract_dims),
+ input.extract_volume_patches(kernelDepth, kernelRows, kernelCols,
+ stridePlanes, strideRows, strideCols,
+ padding_type)
+ .reshape(pre_contract_dims)
+ .contract(kernel.reshape(kernel_dims), contract_dims)
+ .reshape(post_contract_dims));
+}
+
+} // end namespace Eigen
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_
diff --git a/tensorflow/core/kernels/eigen_patch_3d.h b/tensorflow/core/kernels/eigen_patch_3d.h
new file mode 100644
index 0000000000..900d406709
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_patch_3d.h
@@ -0,0 +1,257 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_PATCH_3D_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_PATCH_3D_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+#if not defined(__CUDACC__)
+#include <type_traits>
+#endif
+
+namespace Eigen {
+namespace internal {
+
+/** Extract3DPatches
+ * \ingroup CXX11_NeuralNetworksModule
+ *
+ * \brief Extracts 3D patches from a multichannel input volume.
+ *
+ * The input parameter is expected to be a tensor with a rank of 4 or more
+ * (channels, depth, height, width, optional others in col-major, and the
+ * reverse order in row-major).
+
+ * The return value will be a tensor of 3 more dimension than the input tensor.
+ * In col-major, the first 4 dimensions of the result are: channels, patch_depth,
+ * patch_height, patch_width. The next dimensions will identify the patch
+ * position on the 3D grid of extracted patches: z, y, x. The remaining
+ * dimensions, if any, will be the same as the 'other' dimensions of the input
+ * tensor.
+ */
+
+template <typename Input>
+EIGEN_ALWAYS_INLINE static const TensorStridingOp<
+ const array<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions + 3>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions + 3>,
+ const TensorPatchOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const TensorPaddingOp<
+ const array<IndexPair<typename internal::traits<Input>::Index>,
+ internal::traits<Input>::NumDimensions>,
+ const Input> > > >
+Extract3DPatches(
+ const Input& input, const DenseIndex patchPlanes,
+ const DenseIndex patchRows, const DenseIndex patchCols,
+ const DenseIndex stridePlanes, const DenseIndex strideRows,
+ const DenseIndex strideCols,
+ const DenseIndex paddingZTop, const DenseIndex paddingZBottom,
+ const DenseIndex paddingTop, const DenseIndex paddingBottom,
+ const DenseIndex paddingLeft, const DenseIndex paddingRight,
+ const typename internal::traits<Input>::Scalar padding_value = 0) {
+
+ typedef typename internal::traits<Input>::Index TensorIndex;
+ TensorRef<Tensor<typename internal::traits<Input>::Scalar, internal::traits<Input>::NumDimensions, internal::traits<Input>::Layout, TensorIndex> > in(input);
+
+ EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions >= 4, YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
+ static const int NumDims = internal::traits<Input>::NumDimensions;
+ static const int ExtDims = NumDims + 3;
+
+ // Tensor size after patch extraction. We add three dimensions to unpack the
+ // linear patch index into a 3D grid over which stride() can work.
+ DSizes<TensorIndex, ExtDims> pre_stride_dims;
+
+ if (isColMajor) {
+ pre_stride_dims[0] = in.dimension(0);
+ pre_stride_dims[1] = patchPlanes;
+ pre_stride_dims[2] = patchRows;
+ pre_stride_dims[3] = patchCols;
+ } else {
+ pre_stride_dims[ExtDims - 1] = in.dimension(NumDims - 1);
+ pre_stride_dims[ExtDims - 4] = patchCols;
+ pre_stride_dims[ExtDims - 3] = patchRows;
+ pre_stride_dims[ExtDims - 2] = patchPlanes;
+ }
+
+ const TensorIndex inputPlanes = isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
+ const TensorIndex inputRows = isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
+ const TensorIndex inputCols = isColMajor ? in.dimension(3) : in.dimension(NumDims - 4);
+
+ array<IndexPair<TensorIndex>, NumDims> paddings;
+ for (int i = 0; i < NumDims; ++i) {
+ paddings[i] = IndexPair<TensorIndex>(0, 0);
+ }
+
+ paddings[isColMajor ? 1 : (NumDims - 2)] = IndexPair<TensorIndex>(paddingZTop, paddingZBottom);
+ paddings[isColMajor ? 2 : (NumDims - 3)] = IndexPair<TensorIndex>(paddingTop, paddingBottom);
+ paddings[isColMajor ? 3 : (NumDims - 4)] = IndexPair<TensorIndex>(paddingLeft, paddingRight);
+
+ pre_stride_dims[isColMajor ? 4 : (ExtDims - 5)] = inputPlanes + paddingZBottom + paddingZTop - patchPlanes + 1;
+ pre_stride_dims[isColMajor ? 5 : (ExtDims - 6)] = inputRows + paddingTop + paddingBottom - patchRows + 1;
+ pre_stride_dims[isColMajor ? 6 : (ExtDims - 7)] = inputCols + paddingLeft + paddingRight - patchCols + 1;
+
+ if (isColMajor) {
+ for (int i = 7; i < NumDims + 3; ++i) {
+ pre_stride_dims[i] = in.dimension(i - 3);
+ }
+ } else {
+ for (int i = 0; i < NumDims - 4; ++i) {
+ pre_stride_dims[i] = in.dimension(i);
+ }
+ }
+
+ DSizes<TensorIndex, NumDims> patch_dims;
+ if (isColMajor) {
+ patch_dims[0] = in.dimension(0);
+ patch_dims[1] = patchPlanes;
+ patch_dims[2] = patchRows;
+ patch_dims[3] = patchCols;
+ for (int i = 4; i < NumDims; ++i) {
+ patch_dims[i] = 1;
+ }
+ } else {
+ patch_dims[NumDims - 1] = in.dimension(NumDims - 1);
+ patch_dims[NumDims - 4] = patchCols;
+ patch_dims[NumDims - 3] = patchRows;
+ patch_dims[NumDims - 2] = patchPlanes;
+ for (int i = 0; i < NumDims - 4; i++) {
+ patch_dims[i] = 1;
+ }
+ }
+
+ array<TensorIndex, NumDims + 3> strides;
+ if (isColMajor) {
+ // No striding within the patches.
+ for (int i = 0; i < 4; ++i) {
+ strides[i] = 1;
+ }
+ // Apply striding in the spatial patch grid dimensions only.
+ strides[4] = stridePlanes;
+ strides[5] = strideRows;
+ strides[6] = strideCols;
+ // No striding in the remaining dimensions (batches, ...).
+ for (int i = 7; i < NumDims + 3; i++) {
+ strides[i] = 1;
+ }
+ } else {
+ // No striding within the patches.
+ for (int i = 1; i <= 4; ++i) {
+ strides[ExtDims - i] = 1;
+ }
+ // Apply striding in the spatial patch grid dimensions only.
+ strides[ExtDims - 7] = strideCols;
+ strides[ExtDims - 6] = strideRows;
+ strides[ExtDims - 5] = stridePlanes;
+ // No striding in the remaining dimensions (batches, ...).
+ for (int i = 0; i < NumDims - 4; i++) {
+ strides[i] = 1;
+ }
+ }
+
+ // TODO(mjanusz): Consider getting rid of pad(), and stride() and extend
+ // extract_patches to take additional parameters for padding/striding,
+ // similarly to etract_image_patches.
+ return input.pad(paddings, padding_value).extract_patches(patch_dims).reshape(pre_stride_dims).stride(strides);
+}
+
+
+template <typename Input>
+EIGEN_ALWAYS_INLINE static const TensorStridingOp<
+ const array<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions + 3>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions + 3>,
+ const TensorPatchOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const TensorPaddingOp<
+ const array<IndexPair<typename internal::traits<Input>::Index>,
+ internal::traits<Input>::NumDimensions>,
+ const Input> > > >
+Extract3DPatches(
+ const Input& input, const DenseIndex patchPlanes,
+ const DenseIndex patchRows, const DenseIndex patchCols,
+ const DenseIndex stridePlanes, const DenseIndex strideRows,
+ const DenseIndex strideCols, const PaddingType padding_type,
+ const typename internal::traits<Input>::Scalar padding_value = 0) {
+ typedef typename internal::traits<Input>::Index TensorIndex;
+ TensorRef<Tensor<typename internal::traits<Input>::Scalar, internal::traits<Input>::NumDimensions, internal::traits<Input>::Layout, TensorIndex> > in(input);
+
+ EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions >= 4, YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
+ static const int NumDims = internal::traits<Input>::NumDimensions;
+
+ const TensorIndex inputPlanes = isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
+ const TensorIndex inputRows = isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
+ const TensorIndex inputCols = isColMajor ? in.dimension(3) : in.dimension(NumDims - 4);
+
+ switch (padding_type) {
+ case PADDING_VALID:
+ // No padding in any dimension.
+ return Extract3DPatches(input, patchPlanes, patchRows, patchCols,
+ stridePlanes, strideRows, strideCols,
+ 0, 0, 0, 0, 0, 0, padding_value);
+ case PADDING_SAME: {
+ // The side of the tensor before striding should be just the expected
+ // output times the stride.
+ const TensorIndex size_z = ceil(inputPlanes / static_cast<float>(stridePlanes)) * stridePlanes;
+ const TensorIndex size_y = ceil(inputRows / static_cast<float>(strideRows)) * strideRows;
+ const TensorIndex size_x = ceil(inputCols / static_cast<float>(strideCols)) * strideCols;
+
+ // The size of the patch space is going to be: padded_input_size - patch_size + 1.
+ // This has to match the expected size before striding (pre_stride_dims).
+ // The deltas below extend the input to the expected size.
+ const TensorIndex dz = size_z + patchPlanes - 1 - inputPlanes;
+ const TensorIndex dy = size_y + patchRows - 1 - inputRows;
+ const TensorIndex dx = size_x + patchCols - 1 - inputCols;
+
+ return Extract3DPatches(input, patchPlanes, patchRows, patchCols,
+ stridePlanes, strideRows, strideCols,
+ dz - dz / 2, dz / 2,
+ dy - dy / 2, dy / 2,
+ dx - dx / 2, dx / 2,
+ padding_value);
+ }
+ default:
+ eigen_assert(false && "unexpected padding");
+ // unreachable code to avoid missing return warning.
+ return Extract3DPatches(input, patchPlanes, patchRows, patchCols,
+ stridePlanes, strideRows, strideCols,
+ 0, 0, 0, 0, 0, 0, padding_value);
+ }
+}
+
+// TODO(mjanusz): Switch this to a 'using' alias once CUDA supports C++11.
+template <typename Input>
+struct Extract3DPatchesType {
+ typedef const TensorStridingOp< const array<typename internal::traits<Input>::Index, internal::traits<Input>::NumDimensions + 3>,
+ const TensorReshapingOp< const DSizes<typename internal::traits<Input>::Index, internal::traits<Input>::NumDimensions + 3>,
+ const TensorPatchOp< const DSizes<typename internal::traits<Input>::Index, internal::traits<Input>::NumDimensions>,
+ const TensorPaddingOp< const array< IndexPair<typename internal::traits<Input>::Index>, internal::traits<Input>::NumDimensions>,
+ const Input> > > > type;
+};
+
+} // end namespace internal
+} // end namespace Eigen
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_PATCH_3D_H_
diff --git a/tensorflow/core/kernels/eigen_pooling.h b/tensorflow/core/kernels/eigen_pooling.h
new file mode 100644
index 0000000000..7ded806b74
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_pooling.h
@@ -0,0 +1,441 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_POOLING_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_POOLING_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/eigen_patch_3d.h"
+
+namespace Eigen {
+
+/** SpatialMaxPooling
+ * \ingroup CXX11_NeuralNetworks_Module
+ *
+ * \brief Applies a max-pooling over a multichannel input image.
+ *
+ * The input parameter is expected to be a with a rank of 4 (channels, height, width, others in col-major, and the reverse of that in row-major).
+ *
+ * The result can be assigned to a tensor of rank equal to the rank of the input. The dimensions of the result will be channels, height, width, and others (in col-major, and the reverse of that if the input was row-major).
+ *
+ * The order of the width and height dimensions can be swapped if needed.
+ *
+*/
+#if !defined(EIGEN_HAS_INDEX_LIST)
+template <typename Input>
+EIGEN_ALWAYS_INLINE
+static const TensorReshapingOp<const Eigen::DSizes<typename internal::traits<Input>::Index, internal::traits<Input>::NumDimensions>, const TensorReductionOp<internal::MaxReducer<typename internal::remove_const<typename internal::traits<Input>::Scalar>::type>, const Eigen::array<int, 2>, const TensorImagePatchOp<Dynamic, Dynamic, const Input> > >
+#else
+template <typename Input>
+EIGEN_ALWAYS_INLINE
+static const TensorReshapingOp<const Eigen::DSizes<typename internal::traits<Input>::Index, internal::traits<Input>::NumDimensions>, const TensorReductionOp<internal::MaxReducer<typename internal::remove_const<typename internal::traits<Input>::Scalar>::type>, typename internal::conditional<internal::traits<Input>::Layout == ColMajor, const Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<2> >, const Eigen::IndexList<Eigen::type2index<2>, Eigen::type2index<3> > >::type, const TensorImagePatchOp<Dynamic, Dynamic, const Input> > >
+#endif
+SpatialMaxPooling(const Input& input, DenseIndex patchRows, DenseIndex patchCols,
+ DenseIndex strideRows, DenseIndex strideCols, const PaddingType padding_type,
+ DenseIndex in_strideRows = 1, DenseIndex in_strideCols = 1)
+{
+ EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 4, YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ typedef typename internal::traits<Input>::Index TensorIndex;
+ TensorRef<Tensor<typename internal::traits<Input>::Scalar, internal::traits<Input>::NumDimensions, internal::traits<Input>::Layout, TensorIndex> > in(input);
+
+ const DenseIndex patchRowsEff = patchRows + (patchRows - 1) * (in_strideRows - 1);
+ const DenseIndex patchColsEff = patchCols + (patchCols - 1) * (in_strideCols - 1);
+
+ static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
+ static const int idxRows = isColMajor ? 1 : 2;
+ static const int idxCols = isColMajor ? 2 : 1;
+
+ // Molds the output of the reduction into the shape expected by the user.
+ // (assuming col-major):
+ // - 1st dim: channels
+ // - 2nd dim: output height
+ // - 3rd dim: output width
+ // - 4th dim and beyond: everything else including batch size
+ Eigen::DSizes<TensorIndex, internal::traits<Input>::NumDimensions> post_reduce_dims;
+ post_reduce_dims[0] = in.dimension(0);
+ if (padding_type == PADDING_VALID) {
+ post_reduce_dims[idxRows] = numext::ceil((in.dimension(idxRows) - patchRowsEff + 1.f) / static_cast<float>(strideRows));
+ post_reduce_dims[idxCols] = numext::ceil((in.dimension(idxCols) - patchColsEff + 1.f) / static_cast<float>(strideCols));
+ } else {
+ post_reduce_dims[idxRows] = numext::ceil(in.dimension(idxRows) / static_cast<float>(strideRows));
+ post_reduce_dims[idxCols] = numext::ceil(in.dimension(idxCols) / static_cast<float>(strideCols));
+ }
+ post_reduce_dims[3] = in.dimension(3);
+
+#if !defined(EIGEN_HAS_INDEX_LIST)
+ // nvcc doesn't support cxx11
+ Eigen::array<int, 2> reduction_dims;
+ if (isColMajor) {
+ reduction_dims[0] = 1;
+ reduction_dims[1] = 2;
+ } else {
+ reduction_dims[0] = 2;
+ reduction_dims[1] = 3;
+ }
+#else
+ // Take advantage of cxx11 to give the compiler information it can use to
+ // optimize the code.
+ typename internal::conditional<internal::traits<Input>::Layout == ColMajor, const Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<2> >, const Eigen::IndexList<Eigen::type2index<2>, Eigen::type2index<3> > >::type reduction_dims;
+#endif
+
+ return input.extract_image_patches(patchRows, patchCols, strideRows, strideCols, in_strideRows, in_strideCols, padding_type, -Eigen::NumTraits<typename internal::remove_const<typename internal::traits<Input>::Scalar>::type>::highest()).maximum(reduction_dims).reshape(post_reduce_dims);
+}
+
+/** CuboidMaxPooling
+ * \ingroup CXX11_NeuralNetworks_Module
+ *
+ * \brief Applies a max-pooling over a multichannel input volume.
+ *
+ * The input parameter is expected to be a tensor with a rank of 5 (channels, depth, height, width, others in col-major, and the reverse of that in row-major).
+ *
+ * The result can be assigned to a tensor of rank equal to the rank of the input. The dimensions of the result will be channels, depth, height, width, and others (in col-major, and the reverse of that if the input was row-major).
+ *
+ * The order of the depth, width and height dimensions can be swapped if needed.
+ *
+*/
+#if !defined(EIGEN_HAS_INDEX_LIST)
+template <typename Input>
+EIGEN_ALWAYS_INLINE static const TensorReshapingOp<
+ const Eigen::DSizes<DenseIndex, internal::traits<Input>::NumDimensions>,
+ const TensorReductionOp<
+ internal::MaxReducer<float>, const Eigen::array<int, 1>,
+ const TensorReshapingOp<
+ const Eigen::DSizes<DenseIndex, 3>,
+ const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic, const Input> > > >
+#else
+template <typename Input>
+EIGEN_ALWAYS_INLINE static const TensorReshapingOp<
+ const Eigen::DSizes<DenseIndex, internal::traits<Input>::NumDimensions>,
+ const TensorReductionOp<
+ internal::MaxReducer<float>,
+ const Eigen::IndexList<Eigen::type2index<1> >,
+ const TensorReshapingOp<
+ const Eigen::DSizes<DenseIndex, 3>,
+ const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic, const Input> > > >
+#endif
+CuboidMaxPooling(const Input& input, DenseIndex patchPlanes,
+ DenseIndex patchRows, DenseIndex patchCols,
+ DenseIndex stridePlanes, DenseIndex strideRows,
+ DenseIndex strideCols, const PaddingType padding_type) {
+ EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 5, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
+
+ typedef typename internal::traits<Input>::Index TensorIndex;
+ TensorRef<Tensor<typename internal::traits<Input>::Scalar, internal::traits<Input>::NumDimensions, internal::traits<Input>::Layout, TensorIndex> > in(input);
+
+ static const int idxPlanes = isColMajor ? 1 : 3;
+ static const int idxRows = 2;
+ static const int idxCols = isColMajor ? 3 : 1;
+
+ // Molds the output of the reduction into the shape expected by the used
+ // (assuming col-major):
+ // - 1st dim: channels
+ // - 2nd dim: output depth
+ // - 3rd dim: output height
+ // - 4th dim: output width
+ // - 5th dim and beyond: everything else including batch size
+ Eigen::DSizes<DenseIndex, internal::traits<Input>::NumDimensions> post_reduce_dims;
+ post_reduce_dims[0] = in.dimension(0);
+ if (padding_type == PADDING_VALID) {
+ post_reduce_dims[idxPlanes] = numext::ceil((in.dimension(idxPlanes) - patchPlanes + 1.f) / static_cast<float>(stridePlanes));
+ post_reduce_dims[idxRows] = numext::ceil((in.dimension(idxRows) - patchRows + 1.f) / static_cast<float>(strideRows));
+ post_reduce_dims[idxCols] = numext::ceil((in.dimension(idxCols) - patchCols + 1.f) / static_cast<float>(strideCols));
+ } else {
+ post_reduce_dims[idxPlanes] = numext::ceil(in.dimension(idxPlanes) / static_cast<float>(stridePlanes));
+ post_reduce_dims[idxRows] = numext::ceil(in.dimension(idxRows) / static_cast<float>(strideRows));
+ post_reduce_dims[idxCols] = numext::ceil(in.dimension(idxCols) / static_cast<float>(strideCols));
+ }
+ post_reduce_dims[4] = in.dimension(4);
+
+ Eigen::DSizes<DenseIndex, 3> pre_reduce_dims;
+ pre_reduce_dims[1] = patchRows * patchCols * patchPlanes;
+ if (isColMajor) {
+ pre_reduce_dims[0] = post_reduce_dims[0];
+ pre_reduce_dims[2] = post_reduce_dims[1] * post_reduce_dims[2] * post_reduce_dims[3] * post_reduce_dims[4];
+ } else {
+ pre_reduce_dims[0] = post_reduce_dims[0] * post_reduce_dims[1] * post_reduce_dims[2] * post_reduce_dims[3];
+ pre_reduce_dims[2] = post_reduce_dims[4];
+ }
+
+#if !defined(EIGEN_HAS_INDEX_LIST)
+ // nvcc doesn't support cxx11
+ Eigen::array<int, 1> reduction_dims;
+ reduction_dims[0] = 1;
+#else
+ // Take advantage of cxx11 to give the compiler information it can use to
+ // optimize the code.
+ Eigen::IndexList<Eigen::type2index<1> > reduction_dims;
+#endif
+ return input.extract_volume_patches(patchPlanes, patchRows, patchCols,
+ stridePlanes, strideRows, strideCols,
+ padding_type, -Eigen::NumTraits<float>::highest())
+ .reshape(pre_reduce_dims)
+ .maximum(reduction_dims)
+ .reshape(post_reduce_dims);
+}
+
+
+/** SpatialAvgPooling
+ * \ingroup CXX11_NeuralNetworks_Module
+ *
+ * \brief Applies an average pooling over a multichannel input image.
+ *
+ * The input parameter is expected to be a tensor with a rank of 4 (channels, height, width, others in col-major, and the reverse of that in row-major).
+ *
+ * The result can be assigned to a tensor of rank equal to the rank of the input. The dimensions of the result will be channels, height, width, and others (in col-major, and the reverse of that if the input was row-major).
+ *
+ * The order of the width and height dimensions can be swapped if needed.
+ *
+*/
+namespace internal {
+
+template <typename T> struct AvgPoolMeanReducer
+{
+#if (EIGEN_ARCH_i386 || EIGEN_ARCH_x86_64) && !defined(__CUDACC__)
+ // We only support packet access for floats.
+ static const bool PacketAccess = internal::is_same<T, float>::value;
+#else
+ static const bool PacketAccess = false;
+#endif
+ static const bool IsStateful = true;
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE AvgPoolMeanReducer() : scalarCount_(0) {
+ typedef typename packet_traits<T>::type Packet;
+ packetCount_ = pset1<Packet>(0.0);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) {
+ if (t != -Eigen::NumTraits<T>::highest()) {
+ (*accum) = (*accum) + t;
+ scalarCount_++;
+ }
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
+ return static_cast<T>(0);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
+ eigen_assert(scalarCount_ > 0);
+ return accum / scalarCount_;
+ }
+
+#if (EIGEN_ARCH_i386 || EIGEN_ARCH_x86_64) && !defined(__CUDACC__)
+#ifdef EIGEN_VECTORIZE_AVX
+#define pequal(a,b) _mm256_cmp_ps(a,b,_CMP_EQ_UQ)
+#define psel(a,b,false_mask) _mm256_blendv_ps(a,b,false_mask)
+#else
+#define pequal(a,b) _mm_cmpeq_ps(a,b)
+#define psel(a,b,false_mask) _mm_or_ps(_mm_andnot_ps(false_mask, a), _mm_and_ps(false_mask, b))
+#endif
+
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p, Packet* accum) {
+ reducePacketWithType(static_cast<T>(0), p, accum);
+ }
+
+ template <typename Packet>
+ void reducePacketWithType(T, const Packet& p, Packet* accum) {
+ Packet skip_mask = pequal(p, pset1<Packet>(-Eigen::NumTraits<T>::highest()));
+ (*accum) = padd<Packet>(*accum, psel(p, pset1<Packet>(0), skip_mask));
+ packetCount_ = padd<Packet>(packetCount_, psel(pset1<Packet>(1), pset1<Packet>(0), skip_mask));
+ }
+
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
+ return pset1<Packet>(0);
+ }
+
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet finalizePacket(const Packet& vaccum) const {
+ return pdiv(vaccum, packetCount_);
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const {
+ return (saccum + predux(vaccum)) / (scalarCount_ + predux(packetCount_));
+ }
+#endif
+
+ protected:
+ typedef typename packet_traits<T>::type Packet;
+ int scalarCount_;
+ Packet packetCount_;
+};
+
+} // namespace internal
+
+#if !defined(EIGEN_HAS_INDEX_LIST)
+template <typename Input>
+EIGEN_ALWAYS_INLINE
+static const TensorReshapingOp<const Eigen::DSizes<typename internal::traits<Input>::Index, internal::traits<Input>::NumDimensions>, const TensorReductionOp<internal::AvgPoolMeanReducer<typename internal::remove_const<typename internal::traits<Input>::Scalar>::type>, const Eigen::array<int, 2>, const TensorImagePatchOp<Dynamic, Dynamic, const Input> > >
+#else
+template <typename Input>
+EIGEN_ALWAYS_INLINE
+static const TensorReshapingOp<const Eigen::DSizes<typename internal::traits<Input>::Index, internal::traits<Input>::NumDimensions>, const TensorReductionOp<internal::AvgPoolMeanReducer<typename internal::remove_const<typename internal::traits<Input>::Scalar>::type>, typename internal::conditional<internal::traits<Input>::Layout == ColMajor, const Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<2> >, const Eigen::IndexList<Eigen::type2index<2>, Eigen::type2index<3> > >::type, const TensorImagePatchOp<Dynamic, Dynamic, const Input> > >
+#endif
+SpatialAvgPooling(const Input& input, DenseIndex patchRows, DenseIndex patchCols,
+ DenseIndex strideRows, DenseIndex strideCols, const PaddingType padding_type,
+ DenseIndex in_strideRows = 1, DenseIndex in_strideCols = 1)
+{
+ EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 4, YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ typedef typename internal::traits<Input>::Index TensorIndex;
+ TensorRef<Tensor<typename internal::traits<Input>::Scalar, internal::traits<Input>::NumDimensions, internal::traits<Input>::Layout, TensorIndex> > in(input);
+
+ const DenseIndex patchRowsEff = patchRows + (patchRows - 1) * (in_strideRows - 1);
+ const DenseIndex patchColsEff = patchCols + (patchCols - 1) * (in_strideCols - 1);
+
+ static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
+ static const int idxRows = isColMajor ? 1 : 2;
+ static const int idxCols = isColMajor ? 2 : 1;
+
+ // Molds the output of the reduction into the shape expected by the user.
+ // (assuming col-major):
+ // - 1st dim: channels
+ // - 2nd dim: output height
+ // - 3rd dim: output width
+ // - 4th dim and beyond: everything else including batch size
+ Eigen::DSizes<TensorIndex, internal::traits<Input>::NumDimensions> post_reduce_dims;
+ post_reduce_dims[0] = in.dimension(0);
+ if (padding_type == PADDING_VALID) {
+ post_reduce_dims[idxRows] = numext::ceil((in.dimension(idxRows) - patchRowsEff + 1.f) / static_cast<float>(strideRows));
+ post_reduce_dims[idxCols] = numext::ceil((in.dimension(idxCols) - patchColsEff + 1.f) / static_cast<float>(strideCols));
+ } else {
+ post_reduce_dims[idxRows] = numext::ceil(in.dimension(idxRows) / static_cast<float>(strideRows));
+ post_reduce_dims[idxCols] = numext::ceil(in.dimension(idxCols) / static_cast<float>(strideCols));
+ }
+ post_reduce_dims[3] = in.dimension(3);
+
+ typedef typename internal::remove_const<typename internal::traits<Input>::Scalar>::type CoeffReturnType;
+ internal::AvgPoolMeanReducer<CoeffReturnType> mean_with_nan;
+
+#if !defined(EIGEN_HAS_INDEX_LIST)
+ // nvcc doesn't support cxx11
+ Eigen::array<int, 2> reduction_dims;
+ if (isColMajor) {
+ reduction_dims[0] = 1;
+ reduction_dims[1] = 2;
+ } else {
+ reduction_dims[0] = 2;
+ reduction_dims[1] = 3;
+ }
+#else
+ // Take advantage of cxx11 to give the compiler information it can use to
+ // optimize the code.
+ typename internal::conditional<internal::traits<Input>::Layout == ColMajor, const Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<2> >, const Eigen::IndexList<Eigen::type2index<2>, Eigen::type2index<3> > >::type reduction_dims;
+#endif
+ return input.extract_image_patches(patchRows, patchCols, strideRows, strideCols, in_strideRows, in_strideCols, padding_type, -Eigen::NumTraits<typename internal::remove_const<typename internal::traits<Input>::Scalar>::type>::highest()).reduce(reduction_dims, mean_with_nan).reshape(post_reduce_dims);
+}
+
+
+/** CuboidAvgPooling
+ * \ingroup CXX11_NeuralNetworks_Module
+ *
+ * \brief Applies an average pooling over a multichannel input volume.
+ *
+ * The input parameter is expected to be a tensor with a rank of 5 (channels, depth, height, width, others, and the reverse of that in row-major).
+ *
+ * The result can be assigned to a tensor of rank equal to the rank of the input. The dimensions of the result will be channels, depth, width, and others (in col-major, and the reverse of that if the input was row-major).
+ *
+ * The order of the depth, width and height dimensions can be swapped if needed.
+ *
+*/
+#if !defined(EIGEN_HAS_INDEX_LIST)
+template <typename Input>
+EIGEN_ALWAYS_INLINE static const TensorReshapingOp<
+ const Eigen::DSizes<DenseIndex, internal::traits<Input>::NumDimensions>,
+ const TensorReductionOp<
+ internal::AvgPoolMeanReducer<float>, const Eigen::array<int, 1>,
+ const TensorReshapingOp<
+ const Eigen::DSizes<DenseIndex, 3>,
+ const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic, const Input> > > >
+#else
+template <typename Input>
+EIGEN_ALWAYS_INLINE static const TensorReshapingOp<
+ const Eigen::DSizes<DenseIndex, internal::traits<Input>::NumDimensions>,
+ const TensorReductionOp<
+ internal::AvgPoolMeanReducer<float>,
+ const Eigen::IndexList<Eigen::type2index<1> >,
+ const TensorReshapingOp<
+ const Eigen::DSizes<DenseIndex, 3>,
+ const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic, const Input> > > >
+#endif
+CuboidAvgPooling(const Input& input, DenseIndex patchPlanes,
+ DenseIndex patchRows, DenseIndex patchCols,
+ DenseIndex stridePlanes, DenseIndex strideRows,
+ DenseIndex strideCols, const PaddingType padding_type) {
+ EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 5, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
+
+ typedef typename internal::traits<Input>::Index TensorIndex;
+ TensorRef<Tensor<typename internal::traits<Input>::Scalar, internal::traits<Input>::NumDimensions, internal::traits<Input>::Layout, TensorIndex> > in(input);
+
+ static const int idxPlanes = isColMajor ? 1 : 3;
+ static const int idxRows = 2;
+ static const int idxCols = isColMajor ? 3 : 1;
+ // Molds the output of the reduction into the shape expected by the used
+ // (assuming col-major):
+ // - 1st dim: channels
+ // - 2nd dim: outupt depth
+ // - 3rd dim: output height
+ // - 4th dim: output width
+ // - 5th dim and beyond: everything else including batch size
+ Eigen::DSizes<DenseIndex, internal::traits<Input>::NumDimensions> post_reduce_dims;
+ post_reduce_dims[0] = in.dimension(0);
+ if (padding_type == PADDING_VALID) {
+ post_reduce_dims[idxPlanes] = numext::ceil((in.dimension(idxPlanes) - patchPlanes + 1.f) / static_cast<float>(stridePlanes));
+ post_reduce_dims[idxRows] = numext::ceil((in.dimension(idxRows) - patchRows + 1.f) / static_cast<float>(strideRows));
+ post_reduce_dims[idxCols] = numext::ceil((in.dimension(idxCols) - patchCols + 1.f) / static_cast<float>(strideCols));
+ } else {
+ post_reduce_dims[idxPlanes] = numext::ceil(in.dimension(idxPlanes) / static_cast<float>(stridePlanes));
+ post_reduce_dims[idxRows] = numext::ceil(in.dimension(idxRows) / static_cast<float>(strideRows));
+ post_reduce_dims[idxCols] = numext::ceil(in.dimension(idxCols) / static_cast<float>(strideCols));
+ }
+ post_reduce_dims[4] = in.dimension(4);
+
+ Eigen::DSizes<DenseIndex, 3> pre_reduce_dims;
+ pre_reduce_dims[1] = patchRows * patchCols * patchPlanes;
+ if (isColMajor) {
+ pre_reduce_dims[0] = post_reduce_dims[0];
+ pre_reduce_dims[2] = post_reduce_dims[1] * post_reduce_dims[2] * post_reduce_dims[3] * post_reduce_dims[4];
+ } else {
+ pre_reduce_dims[0] = post_reduce_dims[0] * post_reduce_dims[1] * post_reduce_dims[2] * post_reduce_dims[3];
+ pre_reduce_dims[2] = post_reduce_dims[4];
+ }
+
+ typedef typename internal::remove_const<typename internal::traits<Input>::Scalar>::type CoeffReturnType;
+ internal::AvgPoolMeanReducer<CoeffReturnType> mean_with_nan;
+
+#if !defined(EIGEN_HAS_INDEX_LIST)
+ // nvcc doesn't support cxx11
+ Eigen::array<int, 1> reduction_dims;
+ reduction_dims[0] = 1;
+#else
+ // Take advantage of cxx11 to give the compiler information it can use to
+ // optimize the code.
+ Eigen::IndexList<Eigen::type2index<1> > reduction_dims;
+#endif
+ return input.extract_volume_patches(patchPlanes, patchRows, patchCols,
+ stridePlanes, strideRows, strideCols,
+ padding_type, -Eigen::NumTraits<float>::highest())
+ .reshape(pre_reduce_dims)
+ .reduce(reduction_dims, mean_with_nan)
+ .reshape(post_reduce_dims);
+}
+
+} // end namespace Eigen
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_POOLING_H_
diff --git a/tensorflow/core/kernels/eigen_pooling_test.cc b/tensorflow/core/kernels/eigen_pooling_test.cc
new file mode 100644
index 0000000000..cf6957571f
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_pooling_test.cc
@@ -0,0 +1,742 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/eigen_pooling.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace Eigen {
+
+namespace {
+void EigenApprox(float a, float b) {
+ ASSERT_TRUE(std::abs(a - b) <= std::min(std::abs(a), std::abs(b)) * 1e-3);
+}
+}
+
+TEST(EigenPoolingTest, Simple) {
+ const int depth = 10;
+ const int input_rows = 5;
+ const int input_cols = 5;
+ const int num_batches = 13;
+ const int patch_rows = 4;
+ const int patch_cols = 4;
+ const int output_rows = 2;
+ const int output_cols = 2;
+
+ Tensor<float, 4> input(depth, input_rows, input_cols, num_batches);
+ Tensor<float, 4> result(depth, output_rows, output_cols, num_batches);
+ input = input.constant(11.0f) + input.random();
+ result.setRandom();
+ result = result.constant(-1000.f);
+
+ // Max pooling using a 4x4 window and a stride of 1.
+ const int stride = 1;
+ result = SpatialMaxPooling(input, patch_rows, patch_cols, stride, stride,
+ PADDING_VALID);
+
+ EXPECT_EQ(result.dimension(0), depth);
+ EXPECT_EQ(result.dimension(1), output_rows);
+ EXPECT_EQ(result.dimension(2), output_cols);
+ EXPECT_EQ(result.dimension(3), num_batches);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int d = 0; d < depth; ++d) {
+ for (int i = 0; i < output_rows; ++i) {
+ for (int j = 0; j < output_cols; ++j) {
+ float expected = -10000.f;
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int c = 0; c < patch_cols; ++c) {
+ expected = (std::max)(expected, input(d, r + i, c + j, b));
+ }
+ }
+ if (result(d, i, j, b) != expected) {
+ std::cout << "at d=" << d << " b=" << b << " i=" << i << " j=" << j
+ << " " << result(d, i, j, b) << " vs " << expected
+ << std::endl;
+ }
+ EigenApprox(result(d, i, j, b), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenPoolingTest, SimpleRowMajor) {
+ const int depth = 10;
+ const int input_rows = 5;
+ const int input_cols = 5;
+ const int num_batches = 13;
+ const int patch_rows = 4;
+ const int patch_cols = 4;
+ const int output_rows = 2;
+ const int output_cols = 2;
+
+ Tensor<float, 4, RowMajor> input(num_batches, input_cols, input_rows, depth);
+ Tensor<float, 4, RowMajor> result(num_batches, output_cols, output_rows,
+ depth);
+ input = input.constant(11.0f) + input.random();
+ result.setRandom();
+ result = result.constant(-1000.f);
+
+ // Max pooling using a 4x4 window and a stride of 1.
+ const int stride = 1;
+ result = SpatialMaxPooling(input, patch_rows, patch_cols, stride, stride,
+ PADDING_VALID);
+
+ EXPECT_EQ(result.dimension(3), depth);
+ EXPECT_EQ(result.dimension(2), output_rows);
+ EXPECT_EQ(result.dimension(1), output_cols);
+ EXPECT_EQ(result.dimension(0), num_batches);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int d = 0; d < depth; ++d) {
+ for (int i = 0; i < output_rows; ++i) {
+ for (int j = 0; j < output_cols; ++j) {
+ float expected = -10000.f;
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int c = 0; c < patch_cols; ++c) {
+ expected = (std::max)(expected, input(b, c + j, r + i, d));
+ }
+ }
+ if (result(b, j, i, d) != expected) {
+ std::cout << "at d=" << d << " b=" << b << " i=" << i << " j=" << j
+ << " " << result(b, j, i, d) << " vs " << expected
+ << std::endl;
+ }
+ EigenApprox(result(b, j, i, d), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenPoolingTest, Cuboid) {
+ const int channels = 10;
+ const int input_planes = 5;
+ const int input_rows = 5;
+ const int input_cols = 5;
+ const int num_batches = 13;
+ const int patch_rows = 4;
+ const int patch_cols = 3;
+ const int patch_planes = 2;
+ const int output_rows = 2;
+ const int output_cols = 3;
+ const int output_planes = 4;
+
+ Tensor<float, 5> input(channels, input_planes, input_rows, input_cols,
+ num_batches);
+ Tensor<float, 5> result(channels, output_planes, output_rows, output_cols,
+ num_batches);
+ input = input.constant(11.0f) + input.random();
+ result.setRandom();
+ result = result.constant(-1000.0f);
+
+ // Max pooling using a 4x3x2 window and a stride of 1.
+ const int stride = 1;
+ result = CuboidMaxPooling(input, patch_planes, patch_rows, patch_cols, stride,
+ stride, stride, PADDING_VALID);
+
+ EXPECT_EQ(result.dimension(0), channels);
+ EXPECT_EQ(result.dimension(1), output_planes);
+ EXPECT_EQ(result.dimension(2), output_rows);
+ EXPECT_EQ(result.dimension(3), output_cols);
+ EXPECT_EQ(result.dimension(4), num_batches);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int d = 0; d < channels; ++d) {
+ for (int i = 0; i < output_planes; ++i) {
+ for (int j = 0; j < output_rows; ++j) {
+ for (int k = 0; k < output_cols; ++k) {
+ float expected = -10000.f;
+ for (int p = 0; p < patch_planes; ++p) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int c = 0; c < patch_cols; ++c) {
+ expected =
+ (std::max)(expected, input(d, p + i, r + j, c + k, b));
+ }
+ }
+ }
+ if (result(d, i, j, k, b) != expected) {
+ std::cout << "at d=" << d << " b=" << b << " i=" << i
+ << " j=" << j << " k=" << k << " "
+ << result(d, i, j, k, b) << " vs " << expected
+ << std::endl;
+ }
+ EigenApprox(result(d, i, j, k, b), expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenPoolingTest, CuboidRowMajor) {
+ const int channels = 10;
+ const int input_planes = 5;
+ const int input_rows = 5;
+ const int input_cols = 5;
+ const int num_batches = 13;
+ const int patch_rows = 4;
+ const int patch_cols = 3;
+ const int patch_planes = 2;
+ const int output_rows = 2;
+ const int output_cols = 3;
+ const int output_planes = 4;
+
+ Tensor<float, 5, RowMajor> input(num_batches, input_cols, input_rows,
+ input_planes, channels);
+ Tensor<float, 5, RowMajor> result(num_batches, output_cols, output_rows,
+ output_planes, channels);
+ input = input.constant(11.0f) + input.random();
+ result.setRandom();
+ result = result.constant(-1000.0f);
+
+ // Max pooling using a 4x3x2 window and a stride of 1.
+ const int stride = 1;
+ result = CuboidMaxPooling(input, patch_planes, patch_rows, patch_cols, stride,
+ stride, stride, PADDING_VALID);
+
+ EXPECT_EQ(result.dimension(4), channels);
+ EXPECT_EQ(result.dimension(3), output_planes);
+ EXPECT_EQ(result.dimension(2), output_rows);
+ EXPECT_EQ(result.dimension(1), output_cols);
+ EXPECT_EQ(result.dimension(0), num_batches);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int d = 0; d < channels; ++d) {
+ for (int i = 0; i < output_planes; ++i) {
+ for (int j = 0; j < output_rows; ++j) {
+ for (int k = 0; k < output_cols; ++k) {
+ float expected = -10000.f;
+ for (int p = 0; p < patch_planes; ++p) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int c = 0; c < patch_cols; ++c) {
+ expected =
+ (std::max)(expected, input(b, c + k, r + j, p + i, d));
+ }
+ }
+ }
+ if (result(b, k, j, i, d) != expected) {
+ std::cout << "at d=" << d << " b=" << b << " i=" << i
+ << " j=" << j << " k=" << k << " "
+ << result(b, k, j, i, d) << " vs " << expected
+ << std::endl;
+ }
+ EigenApprox(result(b, k, j, i, d), expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenPoolingTest, ValidCuboid) {
+ const int channels = 10;
+ const int input_planes = 5;
+ const int input_rows = 5;
+ const int input_cols = 5;
+ const int num_batches = 13;
+ const int patch_rows = 4;
+ const int patch_cols = 3;
+ const int patch_planes = 2;
+ const int output_rows = 2;
+ const int output_cols = 3;
+ const int output_planes = 4;
+
+ Tensor<float, 5> input(channels, input_planes, input_rows, input_cols,
+ num_batches);
+ Tensor<float, 5> result(channels, output_planes, output_rows, output_cols,
+ num_batches);
+ input = input.constant(11.0f) + input.random();
+ result.setRandom();
+ result = result.constant(-1000.0f);
+
+ // Max pooling using a 4x3x2 window and a stride of 1.
+ const int stride = 1;
+ result = CuboidAvgPooling(input, patch_planes, patch_rows, patch_cols, stride,
+ stride, stride, PADDING_VALID);
+
+ EXPECT_EQ(result.dimension(0), channels);
+ EXPECT_EQ(result.dimension(1), output_planes);
+ EXPECT_EQ(result.dimension(2), output_rows);
+ EXPECT_EQ(result.dimension(3), output_cols);
+ EXPECT_EQ(result.dimension(4), num_batches);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int d = 0; d < channels; ++d) {
+ for (int i = 0; i < output_planes; ++i) {
+ for (int j = 0; j < output_rows; ++j) {
+ for (int k = 0; k < output_cols; ++k) {
+ float expected_sum = 0.0f;
+ int expected_count = 0;
+ for (int p = 0; p < patch_planes; ++p) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int c = 0; c < patch_cols; ++c) {
+ expected_sum += input(d, p + i, r + j, c + k, b);
+ expected_count++;
+ }
+ }
+ }
+ const float expected = expected_sum / expected_count;
+ if (result(d, i, j, k, b) != expected) {
+ std::cout << "at d=" << d << " b=" << b << " i=" << i
+ << " j=" << j << " k=" << k << " "
+ << result(d, i, j, k, b) << " vs " << expected
+ << std::endl;
+ }
+ EigenApprox(result(d, i, j, k, b), expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenPoolingTest, ValidCuboidRowMajor) {
+ const int channels = 10;
+ const int input_planes = 5;
+ const int input_rows = 5;
+ const int input_cols = 5;
+ const int num_batches = 13;
+ const int patch_rows = 4;
+ const int patch_cols = 3;
+ const int patch_planes = 2;
+ const int output_rows = 2;
+ const int output_cols = 3;
+ const int output_planes = 4;
+
+ Tensor<float, 5, RowMajor> input(num_batches, input_cols, input_rows,
+ input_planes, channels);
+ Tensor<float, 5, RowMajor> result(num_batches, output_cols, output_rows,
+ output_planes, channels);
+ input = input.constant(11.0f) + input.random();
+ result.setRandom();
+ result = result.constant(-1000.0f);
+
+ // Max pooling using a 4x3x2 window and a stride of 1.
+ const int stride = 1;
+ result = CuboidAvgPooling(input, patch_planes, patch_rows, patch_cols, stride,
+ stride, stride, PADDING_VALID);
+
+ EXPECT_EQ(result.dimension(4), channels);
+ EXPECT_EQ(result.dimension(3), output_planes);
+ EXPECT_EQ(result.dimension(2), output_rows);
+ EXPECT_EQ(result.dimension(1), output_cols);
+ EXPECT_EQ(result.dimension(0), num_batches);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int d = 0; d < channels; ++d) {
+ for (int i = 0; i < output_planes; ++i) {
+ for (int j = 0; j < output_rows; ++j) {
+ for (int k = 0; k < output_cols; ++k) {
+ float expected_sum = 0.0f;
+ int expected_count = 0;
+ for (int p = 0; p < patch_planes; ++p) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int c = 0; c < patch_cols; ++c) {
+ expected_sum += input(b, c + k, r + j, p + i, d);
+ expected_count++;
+ }
+ }
+ }
+ const float expected = expected_sum / expected_count;
+ if (result(b, k, j, i, d) != expected) {
+ std::cout << "at d=" << d << " b=" << b << " i=" << i
+ << " j=" << j << " k=" << k << " "
+ << result(b, k, j, i, d) << " vs " << expected
+ << std::endl;
+ }
+ EigenApprox(result(b, k, j, i, d), expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenPoolingTest, SameCuboid) {
+ const int channels = 10;
+ const int input_planes = 5;
+ const int input_rows = 5;
+ const int input_cols = 5;
+ const int num_batches = 13;
+ const int patch_rows = 4;
+ const int patch_cols = 3;
+ const int patch_planes = 2;
+ const int output_rows = input_rows;
+ const int output_cols = input_cols;
+ const int output_planes = input_planes;
+
+ Tensor<float, 5> input(channels, input_planes, input_rows, input_cols,
+ num_batches);
+ Tensor<float, 5> result(channels, output_planes, output_rows, output_cols,
+ num_batches);
+ input = input.constant(11.0f) + input.random();
+ result.setRandom();
+ result = result.constant(-1000.0f);
+
+ // Max pooling using a 4x3x2 window and a stride of 1.
+ const int stride = 1;
+ result = CuboidAvgPooling(input, patch_planes, patch_rows, patch_cols, stride,
+ stride, stride, PADDING_SAME);
+
+ EXPECT_EQ(result.dimension(0), channels);
+ EXPECT_EQ(result.dimension(1), output_planes);
+ EXPECT_EQ(result.dimension(2), output_rows);
+ EXPECT_EQ(result.dimension(3), output_cols);
+ EXPECT_EQ(result.dimension(4), num_batches);
+
+ const int pad_p = output_planes - input_planes + patch_planes - 1;
+ const int pad_r = output_rows - input_rows + patch_rows - 1;
+ const int pad_c = output_cols - input_cols + patch_cols - 1;
+
+ // Number of pixels the input is extended with at the lower end in every
+ // dimension.
+ const int dp = pad_p - pad_p / 2;
+ const int dr = pad_r - pad_r / 2;
+ const int dc = pad_c - pad_c / 2;
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int d = 0; d < channels; ++d) {
+ for (int i = 0; i < output_planes; ++i) {
+ for (int j = 0; j < output_rows; ++j) {
+ for (int k = 0; k < output_cols; ++k) {
+ float expected_sum = 0.0f;
+ int expected_count = 0;
+ for (int p = 0; p < patch_planes; ++p) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int c = 0; c < patch_cols; ++c) {
+ const int in_p = p + i - dp;
+ const int in_r = r + j - dr;
+ const int in_c = c + k - dc;
+ if (in_p >= 0 && in_p < input_planes && in_r >= 0 &&
+ in_r < input_rows && in_c >= 0 && in_c < input_cols) {
+ expected_sum += input(d, in_p, in_r, in_c, b);
+ expected_count++;
+ }
+ }
+ }
+ }
+ const float expected = expected_sum / expected_count;
+ if (result(d, i, j, k, b) != expected) {
+ std::cout << "at d=" << d << " b=" << b << " i=" << i
+ << " j=" << j << " k=" << k << " "
+ << result(d, i, j, k, b) << " vs " << expected
+ << std::endl;
+ }
+ EigenApprox(result(d, i, j, k, b), expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenPoolingTest, SameCuboidRowMajor) {
+ const int channels = 10;
+ const int input_planes = 5;
+ const int input_rows = 5;
+ const int input_cols = 5;
+ const int num_batches = 13;
+ const int patch_rows = 4;
+ const int patch_cols = 3;
+ const int patch_planes = 2;
+ const int output_rows = input_rows;
+ const int output_cols = input_cols;
+ const int output_planes = input_planes;
+
+ Tensor<float, 5, RowMajor> input(num_batches, input_cols, input_rows,
+ input_planes, channels);
+ Tensor<float, 5, RowMajor> result(num_batches, output_cols, output_rows,
+ output_planes, channels);
+ input = input.constant(11.0f) + input.random();
+ result.setRandom();
+ result = result.constant(-1000.0f);
+
+ // Max pooling using a 4x3x2 window and a stride of 1.
+ const int stride = 1;
+ result = CuboidAvgPooling(input, patch_planes, patch_rows, patch_cols, stride,
+ stride, stride, PADDING_SAME);
+
+ EXPECT_EQ(result.dimension(4), channels);
+ EXPECT_EQ(result.dimension(3), output_planes);
+ EXPECT_EQ(result.dimension(2), output_rows);
+ EXPECT_EQ(result.dimension(1), output_cols);
+ EXPECT_EQ(result.dimension(0), num_batches);
+
+ const int pad_p = output_planes - input_planes + patch_planes - 1;
+ const int pad_r = output_rows - input_rows + patch_rows - 1;
+ const int pad_c = output_cols - input_cols + patch_cols - 1;
+
+ // Number of pixels the input is extended with at the lower end in every
+ // dimension.
+ const int dp = pad_p - pad_p / 2;
+ const int dr = pad_r - pad_r / 2;
+ const int dc = pad_c - pad_c / 2;
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int d = 0; d < channels; ++d) {
+ for (int i = 0; i < output_planes; ++i) {
+ for (int j = 0; j < output_rows; ++j) {
+ for (int k = 0; k < output_cols; ++k) {
+ float expected_sum = 0.0f;
+ int expected_count = 0;
+ for (int p = 0; p < patch_planes; ++p) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int c = 0; c < patch_cols; ++c) {
+ const int in_p = p + i - dp;
+ const int in_r = r + j - dr;
+ const int in_c = c + k - dc;
+ if (in_p >= 0 && in_p < input_planes && in_r >= 0 &&
+ in_r < input_rows && in_c >= 0 && in_c < input_cols) {
+ expected_sum += input(b, in_c, in_r, in_p, d);
+ expected_count++;
+ }
+ }
+ }
+ }
+ const float expected = expected_sum / expected_count;
+ if (result(b, k, j, i, d) != expected) {
+ std::cout << "at d=" << d << " b=" << b << " i=" << i
+ << " j=" << j << " k=" << k << " "
+ << result(b, k, j, i, d) << " vs " << expected
+ << std::endl;
+ }
+ EigenApprox(result(b, k, j, i, d), expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+static void test_strided_max_pooling_layer() {
+ const int depth = 10;
+ const int input_rows = 5;
+ const int input_cols = 5;
+ const int num_batches = 13;
+ const int patch_rows = 3;
+ const int patch_cols = 3;
+ const int output_rows = 2;
+ const int output_cols = 2;
+
+ Tensor<float, 4> input(depth, input_rows, input_cols, num_batches);
+ Tensor<float, 4> result(depth, output_rows, output_cols, num_batches);
+ input = input.constant(11.0f) + input.random();
+ result.setRandom();
+
+ // Max pooling using a 3x3 window and a stride of 2.
+ int stride = 2;
+ result = SpatialMaxPooling(input, patch_rows, patch_cols, stride, stride,
+ PADDING_VALID);
+
+ EXPECT_EQ(result.dimension(0), depth);
+ EXPECT_EQ(result.dimension(1), output_rows);
+ EXPECT_EQ(result.dimension(2), output_cols);
+ EXPECT_EQ(result.dimension(3), num_batches);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int d = 0; d < depth; ++d) {
+ for (int i = 0; i < output_rows; ++i) {
+ for (int j = 0; j < output_cols; ++j) {
+ float expected = -10000.f;
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int c = 0; c < patch_cols; ++c) {
+ expected = (std::max)(
+ expected, input(d, r + stride * i, c + stride * j, b));
+ }
+ }
+ if (result(d, i, j, b) != expected) {
+ std::cout << "at d=" << d << " b=" << b << " i=" << i << " j=" << j
+ << " " << result(d, i, j, b) << " vs " << expected
+ << std::endl;
+ }
+ EigenApprox(result(d, i, j, b), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenPoolingTest, Strided) {
+ const int depth = 10;
+ const int input_rows = 5;
+ const int input_cols = 5;
+ const int num_batches = 13;
+ const int patch_rows = 3;
+ const int patch_cols = 3;
+ const int output_rows = 2;
+ const int output_cols = 2;
+
+ Tensor<float, 4, RowMajor> input(num_batches, input_cols, input_rows, depth);
+ Tensor<float, 4, RowMajor> result(num_batches, output_cols, output_rows,
+ depth);
+ input = input.constant(11.0f) + input.random();
+ result.setRandom();
+
+ // Max pooling using a 3x3 window and a stride of 2.
+ int stride = 2;
+ result = SpatialMaxPooling(input, patch_rows, patch_cols, stride, stride,
+ PADDING_VALID);
+
+ EXPECT_EQ(result.dimension(3), depth);
+ EXPECT_EQ(result.dimension(2), output_rows);
+ EXPECT_EQ(result.dimension(1), output_cols);
+ EXPECT_EQ(result.dimension(0), num_batches);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int d = 0; d < depth; ++d) {
+ for (int i = 0; i < output_rows; ++i) {
+ for (int j = 0; j < output_cols; ++j) {
+ float expected = -10000.f;
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int c = 0; c < patch_cols; ++c) {
+ expected = (std::max)(
+ expected, input(b, c + stride * j, r + stride * i, d));
+ }
+ }
+ if (result(b, j, i, d) != expected) {
+ std::cout << "at d=" << d << " b=" << b << " i=" << i << " j=" << j
+ << " " << result(b, j, i, d) << " vs " << expected
+ << std::endl;
+ }
+ EigenApprox(result(b, j, i, d), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenPoolingTest, StridedCuboid) {
+ const int channels = 10;
+ const int input_planes = 5;
+ const int input_rows = 5;
+ const int input_cols = 5;
+ const int num_batches = 13;
+ const int patch_planes = 3;
+ const int patch_rows = 3;
+ const int patch_cols = 3;
+ const int output_planes = 2;
+ const int output_rows = 2;
+ const int output_cols = 2;
+
+ Tensor<float, 5> input(channels, input_planes, input_rows, input_cols,
+ num_batches);
+ Tensor<float, 5> result(channels, output_planes, output_rows, output_cols,
+ num_batches);
+ input = input.constant(11.0f) + input.random();
+ result.setRandom();
+
+ // Max pooling using a 3x3x3 window and a stride of 2.
+ int stride = 2;
+ result = CuboidMaxPooling(input, patch_planes, patch_rows, patch_cols, stride,
+ stride, stride, PADDING_VALID);
+
+ EXPECT_EQ(result.dimension(0), channels);
+ EXPECT_EQ(result.dimension(1), output_planes);
+ EXPECT_EQ(result.dimension(2), output_rows);
+ EXPECT_EQ(result.dimension(3), output_cols);
+ EXPECT_EQ(result.dimension(4), num_batches);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int d = 0; d < channels; ++d) {
+ for (int i = 0; i < output_planes; ++i) {
+ for (int j = 0; j < output_rows; ++j) {
+ for (int k = 0; k < output_cols; ++k) {
+ float expected = -10000.f;
+ for (int p = 0; p < patch_planes; ++p) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int c = 0; c < patch_cols; ++c) {
+ expected = (std::max)(expected,
+ input(d, p + stride * i, r + stride * j,
+ c + stride * k, b));
+ }
+ }
+ }
+ if (result(d, i, j, k, b) != expected) {
+ std::cout << "at d=" << d << " b=" << b << " i=" << i
+ << " j=" << j << " " << k << " "
+ << result(d, i, j, k, b) << " vs " << expected
+ << std::endl;
+ }
+ EigenApprox(result(d, i, j, k, b), expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenPoolingTest, StridedCuboidRowMajor) {
+ const int channels = 10;
+ const int input_planes = 5;
+ const int input_rows = 5;
+ const int input_cols = 5;
+ const int num_batches = 13;
+ const int patch_planes = 3;
+ const int patch_rows = 3;
+ const int patch_cols = 3;
+ const int output_planes = 2;
+ const int output_rows = 2;
+ const int output_cols = 2;
+
+ Tensor<float, 5, RowMajor> input(num_batches, input_cols, input_rows,
+ input_planes, channels);
+ Tensor<float, 5, RowMajor> result(num_batches, output_cols, output_rows,
+ output_planes, channels);
+ input = input.constant(11.0f) + input.random();
+ result.setRandom();
+
+ // Max pooling using a 3x3x3 window and a stride of 2.
+ int stride = 2;
+ result = CuboidMaxPooling(input, patch_planes, patch_rows, patch_cols, stride,
+ stride, stride, PADDING_VALID);
+
+ EXPECT_EQ(result.dimension(4), channels);
+ EXPECT_EQ(result.dimension(3), output_planes);
+ EXPECT_EQ(result.dimension(2), output_rows);
+ EXPECT_EQ(result.dimension(1), output_cols);
+ EXPECT_EQ(result.dimension(0), num_batches);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int d = 0; d < channels; ++d) {
+ for (int i = 0; i < output_planes; ++i) {
+ for (int j = 0; j < output_rows; ++j) {
+ for (int k = 0; k < output_cols; ++k) {
+ float expected = -10000.f;
+ for (int p = 0; p < patch_planes; ++p) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int c = 0; c < patch_cols; ++c) {
+ expected = (std::max)(expected,
+ input(b, c + stride * k, r + stride * j,
+ p + stride * i, d));
+ }
+ }
+ }
+ if (result(b, k, j, i, d) != expected) {
+ std::cout << "at d=" << d << " b=" << b << " i=" << i
+ << " j=" << j << " " << k << " "
+ << result(b, k, j, i, d) << " vs " << expected
+ << std::endl;
+ }
+ EigenApprox(result(b, k, j, i, d), expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+} // namespace Eigen
diff --git a/tensorflow/core/kernels/eigen_softmax.h b/tensorflow/core/kernels/eigen_softmax.h
new file mode 100644
index 0000000000..49123e8062
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_softmax.h
@@ -0,0 +1,90 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_SOFTMAX_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_SOFTMAX_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace Eigen {
+
+/** SoftMax
+ * \ingroup CXX11_NeuralNetworks_Module
+ *
+ * \brief Applies a softmax
+ *
+ * The input parameter is expected to be a col-major tensor with a rank of 2 (depth and other).
+ *
+ * The result can be assigned to a tensor of rank and dimensions equal to that of the input. The result will be laid out in col-major order.
+ *
+*/
+
+namespace {
+struct SoftmaxOp {
+ SoftmaxOp(const float beta) : beta_(beta) { }
+
+ template <typename Input>
+ typename Input::Dimensions dimensions(const Input& input) const {
+ return input.dimensions();
+ }
+
+ template <typename Input, typename Output, typename Device>
+ void eval(const Input& input, Output& output, const Device& device) const
+ {
+#if !defined(EIGEN_HAS_INDEX_LIST)
+ // nvcc doesn't support cxx11
+ Eigen::array<typename internal::traits<Input>::Index, 1> depth_dim;
+ depth_dim[0] = 0;
+ Eigen::array<typename internal::traits<Input>::Index, 2> bcast;
+ bcast[0] = dimensions(input)[0];
+ bcast[1] = 1;
+ DSizes<typename internal::traits<Input>::Index, 2> dims2d;
+ dims2d[0] = 1;
+ dims2d[1] = dimensions(input)[1];
+#else
+ // Take advantage of cxx11 to give the compiler information it can use to
+ // optimize the code.
+ Eigen::IndexList<Eigen::type2index<0>> depth_dim;
+ Eigen::IndexList<int, Eigen::type2index<1>> bcast;
+ bcast.set(0, dimensions(input)[0]);
+ Eigen::IndexList<Eigen::type2index<1>, typename internal::traits<Input>::Index> dims2d;
+ dims2d.set(1, dimensions(input)[1]);
+#endif
+
+ output.device(device) = ((input - input.maximum(depth_dim).eval().reshape(dims2d).broadcast(bcast)) * beta_).exp();
+ output.device(device) = output / (output.sum(depth_dim).eval().reshape(dims2d).broadcast(bcast));
+ }
+
+ private:
+ const float beta_;
+};
+}
+
+
+template <typename Input>
+EIGEN_ALWAYS_INLINE
+static const TensorCustomUnaryOp<const SoftmaxOp, const Input>
+SoftMax(const Input& input, const float beta)
+{
+ EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == ColMajor, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 2, YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ const SoftmaxOp op(beta);
+ return input.customOp(op);
+}
+
+} // end namespace Eigen
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_SOFTMAX_H_
diff --git a/tensorflow/core/kernels/eigen_softmax_test.cc b/tensorflow/core/kernels/eigen_softmax_test.cc
new file mode 100644
index 0000000000..8623861518
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_softmax_test.cc
@@ -0,0 +1,65 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/eigen_softmax.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace Eigen {
+
+namespace {
+void EigenApprox(float a, float b) {
+ ASSERT_TRUE(std::abs(a - b) <= std::min(std::abs(a), std::abs(b)) * 1e-3);
+}
+}
+
+TEST(EigenSoftmaxTest, Simple) {
+ const int depth = 1024;
+ const int batch = 32;
+ const float beta = 1.2f;
+
+ Tensor<float, 2> input(depth, batch);
+ input = input.constant(11.0f) + input.random();
+
+ Tensor<float, 2> reference(depth, batch);
+ reference.setRandom();
+
+ Eigen::array<int, 1> depth_dim;
+ depth_dim[0] = 0;
+ Eigen::array<int, 2> bcast;
+ bcast[0] = depth;
+ bcast[1] = 1;
+ Tensor<float, 2>::Dimensions dims2d;
+ dims2d[0] = 1;
+ dims2d[1] = batch;
+ reference =
+ ((input -
+ input.maximum(depth_dim).eval().reshape(dims2d).broadcast(bcast)) *
+ beta)
+ .exp();
+ reference =
+ reference /
+ (reference.sum(depth_dim).eval().reshape(dims2d).broadcast(bcast));
+
+ Tensor<float, 2> result = SoftMax(input, beta);
+
+ for (int i = 0; i < depth; ++i) {
+ for (int j = 0; j < batch; ++j) {
+ EigenApprox(result(i, j), reference(i, j));
+ }
+ }
+}
+
+} // namespace Eigen
diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions.h b/tensorflow/core/kernels/eigen_spatial_convolutions.h
new file mode 100644
index 0000000000..53a3e99b19
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_spatial_convolutions.h
@@ -0,0 +1,785 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace Eigen {
+
+namespace internal {
+
+// These optimizations require vector instructions
+#ifdef EIGEN_VECTORIZE
+
+// TODO: Consolidate this part of the code with the image patch extraction code
+// since they are both very similar.
+template <typename NewDimension, DenseIndex Rows, DenseIndex Cols, typename ArgType, typename Device,
+ typename Scalar_, typename Index,
+ typename nocontract_t, typename contract_t,
+ int Side, size_t packet_size,
+ bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
+class TensorContractionInputMapper<Scalar_, Index, Side, TensorEvaluator<const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, Device>, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
+{
+ public:
+ typedef Scalar_ Scalar;
+ typedef TensorContractionInputMapper<Scalar, Index, Side, TensorEvaluator<const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, Device>, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self;
+ typedef TensorContractionSubMapper<Scalar, Index, Side, TensorEvaluator<const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, Device>, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
+ typedef SubMapper VectorMapper;
+ typedef SubMapper LinearMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ TensorContractionInputMapper(const TensorEvaluator<const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, Device>& tensor,
+ const nocontract_t&, const nocontract_t&,
+ const contract_t&, const contract_t&)
+ : m_impl(tensor.impl().impl())
+ {
+ Index patch_rows;
+ Index patch_depth;
+ if (internal::traits<ArgType>::Layout == ColMajor) {
+ patch_depth = tensor.impl().dimensions()[0];
+ patch_rows = tensor.impl().dimensions()[1];
+ m_patch_cols = tensor.impl().dimensions()[2];
+ m_num_patches = tensor.impl().dimensions()[3];
+ } else {
+ static const int NumDims = tensor.impl().dimensions().size();
+ patch_depth = tensor.impl().dimensions()[NumDims - 1];
+ patch_rows = tensor.impl().dimensions()[NumDims - 2];
+ m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
+ m_num_patches = tensor.impl().dimensions()[NumDims - 4];
+ }
+ m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
+ m_patch_col_inflate_strides = tensor.impl().colInflateStride();
+
+ m_colStride = patch_rows;
+
+ m_outputRows = tensor.impl().outputRows();
+ m_row_strides = tensor.impl().userRowStride();
+ m_col_strides = tensor.impl().userColStride();
+
+ m_in_row_strides = tensor.impl().userInRowStride();
+ m_in_col_strides = tensor.impl().userInColStride();
+
+ if (internal::traits<ArgType>::Layout == ColMajor) {
+ m_inputRows = tensor.impl().impl().dimensions()[1];
+ m_inputCols = tensor.impl().impl().dimensions()[2];
+ } else {
+ static const int NumDims = tensor.impl().impl().dimensions().size();
+ m_inputRows = tensor.impl().impl().dimensions()[NumDims - 2];
+ m_inputCols = tensor.impl().impl().dimensions()[NumDims - 3];
+ }
+
+ m_rowInputStride = patch_depth;
+ m_colInputStride = patch_depth * m_inputRows;
+ m_patchInputStride = patch_depth * m_inputRows * m_inputCols;
+
+ m_rowPaddingTop = tensor.impl().rowPaddingTop();
+ m_colPaddingLeft = tensor.impl().colPaddingLeft();
+
+ m_fastInputRowStride = internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
+ m_fastInputColStride = internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
+ m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
+ m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
+ m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
+ m_fastDimZero = internal::TensorIntDivisor<Index>(patch_depth);
+ }
+
+ TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper) :
+ m_impl(base_mapper.m_impl) {
+ m_patch_cols = base_mapper.m_patch_cols;
+ m_num_patches = base_mapper.m_num_patches;
+ m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
+ m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
+
+ m_colStride = base_mapper.m_colStride;
+
+ m_rowInputStride = base_mapper.m_rowInputStride;
+ m_colInputStride = base_mapper.m_colInputStride;
+ m_patchInputStride = base_mapper.m_patchInputStride;
+
+ m_inputRows = base_mapper.m_inputRows;
+ m_inputCols = base_mapper.m_inputCols;
+
+ m_outputRows = base_mapper.m_outputRows;
+ m_row_strides = base_mapper.m_row_strides;
+ m_col_strides = base_mapper.m_col_strides;
+
+ m_in_row_strides = base_mapper.m_in_row_strides;
+ m_in_col_strides = base_mapper.m_in_col_strides;
+
+ m_rowPaddingTop = base_mapper.m_rowPaddingTop;
+ m_colPaddingLeft = base_mapper.m_colPaddingLeft;
+
+ m_fastInputRowStride = base_mapper.m_fastInputRowStride;
+ m_fastInputColStride = base_mapper.m_fastInputColStride;
+ m_fastNumPatches = base_mapper.m_fastNumPatches;
+ m_fastColStride = base_mapper.m_fastColStride;
+ m_fastOutputRows = base_mapper.m_fastOutputRows;
+ m_fastDimZero = base_mapper.m_fastDimZero;
+ }
+
+ // If true, turns off some optimizations for loading packets since the image
+ // patches are "non-standard" such as there are non-trivial strides or
+ // inflations in the input.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
+ return m_in_row_strides != 1 || m_in_col_strides != 1 || m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
+ return SubMapper(*this, i, j);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
+ return LinearMapper(*this, i, j);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
+ Index rowIndex, colIndex, otherIndex;
+ computeBaseIndices(0, rowIndex, colIndex, otherIndex);
+ return loadCoeff(row, rowIndex, colIndex, otherIndex);
+ }
+
+ // Load the coefficient at the patchIndex location instead of the usual m_rowIndex,
+ // m_colIndex, m_otherIndex. This is currently only used by the gpu code. EIGEN_DEVICE_FUNC
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const {
+ Index rowIndex, colIndex, otherIndex;
+ computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
+ return loadCoeff(row, rowIndex, colIndex, otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const {
+ Index rowIndex, colIndex, otherIndex;
+ computeBaseIndices(0, rowIndex, colIndex, otherIndex);
+ return loadPacket(row, rowIndex, colIndex, otherIndex);
+ }
+
+ // Load the packet at the patchIndex location instead of the usual m_rowIndex,
+ // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const {
+ Index rowIndex, colIndex, otherIndex;
+ computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
+ return loadPacket(row, rowIndex, colIndex, otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_rowInputStride; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRows() const { return m_colStride; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth, const Index baseIndex) const {
+ const Index inputIndex = depth + baseIndex;
+ return m_impl.template packet<Unaligned>(inputIndex);
+ }
+
+ private:
+ friend class TensorContractionSubMapper<Scalar, Index, Side, TensorEvaluator<const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, Device>, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>;
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const {
+ // Find the offset of the element wrt the location of the first element.
+ const Index patchOffset = patchId / m_fastDimZero;
+
+ const Index colOffset = patchOffset / m_fastColStride;
+ const Index inputCol = colIndex + colOffset * m_in_col_strides;
+ const Index origInputCol = (m_patch_col_inflate_strides == 1) ? inputCol : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
+ const Index rowOffset = patchOffset - colOffset * m_colStride;
+ const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
+ const Index origInputRow = (m_patch_row_inflate_strides == 1) ? inputRow : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
+ if (origInputCol < 0 || origInputRow < 0 || origInputCol >= m_inputCols ||
+ origInputRow >= m_inputRows ||
+ (inputCol != origInputCol * m_patch_col_inflate_strides) ||
+ (inputRow != origInputRow * m_patch_row_inflate_strides)) {
+ return Scalar(0);
+ }
+ const Index depth = patchId - patchOffset * patchDepth();
+ const Index inputIndex = depth + origInputRow * m_rowInputStride + origInputCol * m_colInputStride + otherIndex;
+ return m_impl.coeff(inputIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const {
+ eigen_assert(!nonStandardPatches());
+
+ // Find the offset of the element wrt the location of the first element.
+ const Index patchOffset = patchId / m_fastDimZero;
+
+ const Index colOffset = patchOffset / m_fastColStride;
+ const Index inputCol = colIndex + colOffset;
+ const Index rowOffset = patchOffset - colOffset * m_colStride;
+ const Index inputRow = rowIndex + rowOffset;
+ if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 || inputRow >= m_inputRows) {
+ return Scalar(0);
+ }
+ const Index depth = patchId - patchOffset * patchDepth();
+ const Index inputIndex = depth + inputRow * m_rowInputStride + inputCol * m_colInputStride + otherIndex;
+ return m_impl.coeff(inputIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const {
+ const Index packetSize = internal::unpacket_traits<Packet>::size;
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(patchId < patchDepth()*patchRows()*m_patch_cols);
+
+ if (nonStandardPatches()) {
+ return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
+ }
+ return loadPacketStandard(patchId, rowIndex, colIndex, otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const {
+ const Index packetSize = internal::unpacket_traits<Packet>::size;
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(patchId < patchDepth()*patchRows()*m_patch_cols);
+
+ eigen_assert(!nonStandardPatches());
+
+ if ((patchDepth() % packetSize) == 0) {
+ return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
+ }
+ else {
+ const Index patchOffsets[2] = {patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};
+
+ const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride, patchOffsets[1] / m_fastColStride};
+
+ const Index inputCols[2] = {colIndex + colOffsets[0], colIndex + colOffsets[1]};
+ if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
+ // all zeros
+ return internal::pset1<Packet>(Scalar(0));
+ }
+
+ if (inputCols[0] == inputCols[1]) {
+ const Index rowOffsets[2] = {patchOffsets[0] - colOffsets[0]*m_colStride, patchOffsets[1] - colOffsets[1]*m_colStride};
+ eigen_assert(rowOffsets[0] <= rowOffsets[1]);
+ const Index inputRows[2] = {rowIndex + rowOffsets[0], rowIndex + rowOffsets[1]};
+
+ if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
+ // all zeros
+ return internal::pset1<Packet>(Scalar(0));
+ }
+
+ if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) {
+ // no padding
+ const Index depth = patchId - patchOffsets[0] * patchDepth();
+ const Index inputIndex = depth + inputRows[0] * m_rowInputStride + inputCols[0] * m_colInputStride + otherIndex;
+ return m_impl.template packet<Unaligned>(inputIndex);
+ }
+ }
+ }
+ return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const {
+ const Index packetSize = internal::unpacket_traits<Packet>::size;
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(patchId < patchDepth()*patchRows()*m_patch_cols);
+
+ eigen_assert(!nonStandardPatches());
+ eigen_assert((patchDepth() % packetSize) == 0);
+ // Find the offset of the element wrt the location of the first element.
+ const Index patchOffset = patchId / m_fastDimZero;
+ eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
+
+ const Index colOffset = patchOffset / m_fastColStride;
+ const Index inputCol = colIndex + colOffset;
+ const Index rowOffset = patchOffset - colOffset*m_colStride;
+ const Index inputRow = rowIndex + rowOffset;
+ if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols ||
+ inputRow >= m_inputRows) {
+ // all zeros
+ return internal::pset1<Packet>(Scalar(0));
+ }
+ // no padding
+ const Index depth = patchId - patchOffset * patchDepth();
+ const Index inputIndex = depth + inputRow * m_rowInputStride + inputCol * m_colInputStride + otherIndex;
+ return m_impl.template packet<Unaligned>(inputIndex);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero(Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const
+ {
+ const int packetSize = internal::unpacket_traits<Packet>::size;
+ EIGEN_ALIGN_MAX typename internal::remove_const<Scalar>::type values[packetSize];
+ for (int i = 0; i < packetSize; ++i) {
+ values[i] = loadCoeff(patchId+i, rowIndex, colIndex, otherIndex);
+ }
+ Packet rslt = internal::pload<Packet>(values);
+ return rslt;
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(Index patchIndex, Index& rowIndex, Index& colIndex, Index& otherIndex) const {
+ const int NumInputDims = array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
+ otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches;
+ const Index patch2DIndex = (NumInputDims == 3) ? patchIndex : (patchIndex - otherIndex * m_num_patches);
+ otherIndex *= m_patchInputStride;
+ colIndex = patch2DIndex / m_fastOutputRows;
+ rowIndex = patch2DIndex - colIndex * m_outputRows;
+ colIndex = colIndex * m_col_strides - m_colPaddingLeft;
+ rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
+ }
+
+ Index m_patch_cols; // number of colums in the patch
+ Index m_num_patches; // number of patches to extract.
+ Index m_patch_row_inflate_strides; // the strides for row inflation in the image patch
+ Index m_patch_col_inflate_strides; // the strides for col inflation in the image patch
+ // Fast representation of inflation strides.
+ internal::TensorIntDivisor<Index> m_fastInputRowStride;
+ internal::TensorIntDivisor<Index> m_fastInputColStride;
+
+ Index m_otherStride;
+ Index m_colStride;
+ internal::TensorIntDivisor<Index> m_fastNumPatches;
+ internal::TensorIntDivisor<Index> m_fastColStride;
+
+ Index m_rowInputStride; // row stride in the input tensor
+ Index m_colInputStride; // col stride in the input tensor
+ Index m_patchInputStride; // patch stride in the input tensor
+
+ Index m_inputRows; // Number of rows in the input tensor
+ Index m_inputCols; // Number of cols in the input tensor
+
+ Index m_outputRows; // Number of patch rows
+
+ Index m_row_strides; // User specified row stride
+ Index m_col_strides; // User specified col stride
+
+ Index m_in_row_strides; // User specified input row stride
+ Index m_in_col_strides; // User specified input col stride
+
+ Index m_rowPaddingTop; // Row padding
+ Index m_colPaddingLeft; // Column padding
+
+ internal::TensorIntDivisor<Index> m_fastOutputRows;
+ internal::TensorIntDivisor<Index> m_fastDimZero;
+
+ const TensorEvaluator<ArgType, Device> m_impl;
+};
+
+
+template <typename NewDimension, DenseIndex Rows, DenseIndex Cols, typename ArgType, typename Device,
+ typename Scalar, typename Index,
+ typename nocontract_t, typename contract_t,
+ int Side, size_t packet_size,
+ bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
+class TensorContractionSubMapper<Scalar, Index, Side, TensorEvaluator<const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, Device>, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
+{
+ public:
+ typedef typename packet_traits<Scalar>::type Packet;
+ typedef typename packet_traits<Scalar>::half HalfPacket;
+
+ typedef TensorContractionInputMapper<Scalar, Index, Side, TensorEvaluator<const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, Device>, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper;
+ typedef TensorContractionSubMapper<Scalar, Index, Side, TensorEvaluator<const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, Device>, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self;
+ typedef Self LinearMapper;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
+ : m_base_mapper(base_mapper), m_depth_offset(vert_offset), m_col_offset(horiz_offset) {
+ m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(const Self& base_mapper, Index vert_offset, Index horiz_offset)
+ : m_base_mapper(base_mapper.m_base_mapper), m_depth_offset(vert_offset+base_mapper.m_depth_offset), m_col_offset(horiz_offset+base_mapper.m_col_offset) {
+ m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
+ return m_base_mapper.loadCoeff(i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
+ return m_base_mapper(i + m_depth_offset, j + m_col_offset);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
+ return m_base_mapper.loadPacket(i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
+ return m_base_mapper.template loadPacket(i + m_depth_offset, j + m_col_offset);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar loadCoeffStandard(Index i) const {
+ return m_base_mapper.loadCoeffStandard(i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const {
+ return m_base_mapper.loadPacketFast(i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index i) const {
+ return m_base_mapper.loadPacketStandard(i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC bool aligned(Index) const {
+ return false;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
+ return m_base_mapper.nonStandardPatches();
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_base_mapper.m_rowInputStride; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRows() const { return m_base_mapper.m_colStride; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchCols() const { return m_base_mapper.m_patch_cols; }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth, const Index baseIndex) const {
+ const Index inputIndex = depth + baseIndex;
+ return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
+ const Index r = m_rowIndex + row;
+ return r < 0 || r >= m_base_mapper.m_inputRows;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool padCol(const Index col) const {
+ const Index c = m_colIndex + col;
+ return c < 0 || c >= m_base_mapper.m_inputCols;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index baseIndex(const Index row, const Index col) const {
+ const Index r = m_rowIndex + row;
+ const Index c = m_colIndex + col;
+ return r * m_base_mapper.m_rowInputStride + c * m_base_mapper.m_colInputStride + m_otherIndex;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index rowOffset() const {
+ const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
+ const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
+ return patchOffset-colOffset*m_base_mapper.m_colStride;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index colOffset() const {
+ const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
+ const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
+ return colOffset;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index depthOffset() const {
+ const Index patchOffset = m_depth_offset % m_base_mapper.patchDepth();
+ return patchOffset;
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
+ return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
+ }
+
+ private:
+ const ParentMapper& m_base_mapper; // that was a reference before
+ Index m_depth_offset; // First row in the input matrix
+ Index m_col_offset; // First col in the input matrix
+
+ Index m_rowIndex; // precomputed row index corresponding to the col offset
+ Index m_colIndex; // precomputed col index corresponding to the col offset
+ Index m_otherIndex; // precomputed other index corresponding to the col offset
+};
+
+
+template <typename NewDimension, DenseIndex Rows, DenseIndex Cols, typename ArgType, typename Device,
+ typename Scalar, typename Index,
+ typename nocontract_t, typename contract_t,
+ int Side, size_t packet_size,
+ bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, int nr>
+struct gemm_pack_rhs<Scalar, Index, TensorContractionSubMapper<Scalar, Index, Side, TensorEvaluator<const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, Device>, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>, nr, ColMajor, false, false> {
+
+ typedef TensorContractionSubMapper<Scalar, Index, Side, TensorEvaluator<const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, Device>, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
+ typedef SubMapper DataMapper;
+
+ static inline Index ceil_div(Index a, Index b) {
+ return (a + b - 1) / b;
+ }
+
+ EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0) const {
+ eigen_assert(stride == 0);
+ eigen_assert(offset == 0);
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ typedef typename DataMapper::LinearMapper LinearMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ const Index packet_cols4 = (cols/4) * 4;
+ const Index peeled_k = (depth/packet_size) * packet_size;
+ const bool non_standard_patches = rhs.nonStandardPatches();
+
+ for(Index j2=0; j2<packet_cols4; j2+=4)
+ {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
+ const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
+ const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
+ const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
+
+ Index k=0;
+ if((packet_size%4)==0 && !non_standard_patches)
+ {
+ const Index patch_depth = rhs.patchDepth();
+ if ((patch_depth % packet_size) == 0) {
+ const Index patch_cols = rhs.patchCols();
+ const Index patch_rows = rhs.patchRows();
+
+ const Index startCol = rhs.colOffset();
+ const Index max_cols = std::min<Index>(ceil_div(peeled_k, patch_rows*patch_depth)+startCol, patch_cols);
+
+ for (Index c = startCol; c < max_cols; ++c) {
+ eigen_assert(k < peeled_k);
+ const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
+ const Index max_rows = std::min<Index>(ceil_div(peeled_k-c*patch_rows*patch_depth, patch_depth)+startRow, patch_rows);
+
+ const bool pad_col0 = dm0.padCol(c);
+ const bool pad_col1 = dm1.padCol(c);
+ const bool pad_col2 = dm2.padCol(c);
+ const bool pad_col3 = dm3.padCol(c);
+ for (Index r = startRow; r < max_rows; ++r) {
+ eigen_assert(k < peeled_k);
+ const bool pad0 = pad_col0 || dm0.padRow(r);
+ const bool pad1 = pad_col1 || dm1.padRow(r);
+ const bool pad2 = pad_col2 || dm2.padRow(r);
+ const bool pad3 = pad_col3 || dm3.padRow(r);
+
+ const Index idx0 = dm0.baseIndex(r, c);
+ const Index idx1 = dm1.baseIndex(r, c);
+ const Index idx2 = dm2.baseIndex(r, c);
+ const Index idx3 = dm3.baseIndex(r, c);
+
+ const Index startDepth = ((c == startCol) && (r == startRow)) ? rhs.depthOffset() : 0;
+ const Index max_depth = std::min<Index>(peeled_k-c*patch_rows*patch_depth-r*patch_depth+startDepth, patch_depth);
+ eigen_assert(max_depth % packet_size == 0);
+ for (Index d = startDepth; d < max_depth; d += packet_size) {
+ eigen_assert(k < peeled_k);
+ PacketBlock<Packet, 4> kernel;
+ kernel.packet[0] = pad0 ? pset1<Packet>(0) : rhs.packetNoPadding(d, idx0);
+ kernel.packet[1] = pad1 ? pset1<Packet>(0) : rhs.packetNoPadding(d, idx1);
+ kernel.packet[2] = pad2 ? pset1<Packet>(0) : rhs.packetNoPadding(d, idx2);
+ kernel.packet[3] = pad3 ? pset1<Packet>(0) : rhs.packetNoPadding(d, idx3);
+ ptranspose(kernel);
+ pstoreu(block+0*packet_size, kernel.packet[0]);
+ pstoreu(block+1*packet_size, kernel.packet[1]);
+ pstoreu(block+2*packet_size, kernel.packet[2]);
+ pstoreu(block+3*packet_size, kernel.packet[3]);
+ block+=4*packet_size;
+ k += packet_size;
+ }
+ }
+ }
+
+ for(; k<peeled_k; k+=packet_size) {
+ PacketBlock<Packet, 4> kernel;
+ kernel.packet[0] = dm0.loadPacketFast(k);
+ kernel.packet[1] = dm1.loadPacketFast(k);
+ kernel.packet[2] = dm2.loadPacketFast(k);
+ kernel.packet[3] = dm3.loadPacketFast(k);
+ ptranspose(kernel);
+ pstoreu(block+0*packet_size, kernel.packet[0]);
+ pstoreu(block+1*packet_size, kernel.packet[1]);
+ pstoreu(block+2*packet_size, kernel.packet[2]);
+ pstoreu(block+3*packet_size, kernel.packet[3]);
+ block+=4*packet_size;
+ }
+ }
+ else {
+ for(; k<peeled_k; k+=packet_size) {
+ PacketBlock<Packet, 4> kernel;
+ kernel.packet[0] = dm0.loadPacketStandard(k);
+ kernel.packet[1] = dm1.loadPacketStandard(k);
+ kernel.packet[2] = dm2.loadPacketStandard(k);
+ kernel.packet[3] = dm3.loadPacketStandard(k);
+ ptranspose(kernel);
+ pstoreu(block+0*packet_size, kernel.packet[0]);
+ pstoreu(block+1*packet_size, kernel.packet[1]);
+ pstoreu(block+2*packet_size, kernel.packet[2]);
+ pstoreu(block+3*packet_size, kernel.packet[3]);
+ block+=4*packet_size;
+ }
+ }
+ }
+ if (!rhs.nonStandardPatches()) {
+ for(; k<depth; k++)
+ {
+ block[0] = dm0.loadCoeffStandard(k);
+ block[1] = dm1.loadCoeffStandard(k);
+ block[2] = dm2.loadCoeffStandard(k);
+ block[3] = dm3.loadCoeffStandard(k);
+ block += 4;
+ }
+ }
+ else {
+ for(; k<depth; k++)
+ {
+ block[0] = dm0(k);
+ block[1] = dm1(k);
+ block[2] = dm2(k);
+ block[3] = dm3(k);
+ block += 4;
+ }
+ }
+ }
+
+ // copy the remaining columns one at a time (nr==1)
+ for(Index j2=packet_cols4; j2<cols; ++j2)
+ {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2);
+ for(Index k=0; k<depth; k++)
+ {
+ *block = dm0(k);
+ block += 1;
+ }
+ }
+ }
+};
+
+#endif // EIGEN_VECTORIZE
+} // end namespace internal
+
+
+/** SpatialConvolution
+ * \ingroup CXX11_NeuralNetworks_Module
+ *
+ * \brief Applies a 2D convolution over a multichannel input image.
+ *
+ * The input parameter is expected to be a tensor with a rank of 3 or more (channels, height, width, and optionally others)
+ * The kernel parameter is expected to be a 4D tensor (filters, channels, kernel_height, kernel_width)
+ * The input and the kernel must both be in col-major layout. The result will also be in col-major layout.
+ *
+ * If in_stride > 1, then applies convolution with holes (aka atrous convolution), sampling every in_stride input pixels.
+ *
+ * The result can be assigned to a tensor of rank equal to the rank of the input. The dimensions of the result will be filters, height, width (and others if applicable).
+ *
+ * It is possible to swap the order of the width and height dimensions provided that the same order is used in the input, the kernel, and the output.
+ *
+ */
+template <typename Input, typename Kernel>
+EIGEN_ALWAYS_INLINE
+static const typename internal::conditional<
+ internal::traits<Input>::Layout == ColMajor,
+ TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, internal::traits<Input>::NumDimensions>, const TensorContractionOp<const array<IndexPair<typename internal::traits<Input>::Index>, 1>, const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>, const Kernel>, const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>, const TensorImagePatchOp<Dynamic, Dynamic, const Input> > > >,
+ TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, internal::traits<Input>::NumDimensions>, const TensorContractionOp<const array<IndexPair<typename internal::traits<Input>::Index>, 1>, const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>, const TensorImagePatchOp<Dynamic, Dynamic, const Input> >, const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>, const Kernel> > > >::type
+SpatialConvolution(const Input& input, const Kernel& kernel, const DenseIndex stride = 1, const PaddingType padding_type = PADDING_SAME, const DenseIndex in_stride = 1) {
+
+ typedef typename internal::traits<Input>::Index TensorIndex;
+ TensorRef<Tensor<typename internal::traits<Input>::Scalar, internal::traits<Input>::NumDimensions, internal::traits<Input>::Layout, TensorIndex> > in(input);
+ TensorRef<Tensor<typename internal::traits<Kernel>::Scalar, internal::traits<Kernel>::NumDimensions, internal::traits<Kernel>::Layout, TensorIndex> > kern(kernel);
+
+ EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == internal::traits<Kernel>::Layout, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
+
+ static const int NumDims = internal::traits<Input>::NumDimensions;
+
+ // Number of filters to apply. This is the same as the output depth of the result
+ const TensorIndex kernelFilters = isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
+ // Number of channels. This is the same as the input depth.
+ const TensorIndex kernelChannels = isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
+ const TensorIndex kernelRows = isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
+ const TensorIndex kernelCols = isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];
+
+ const DenseIndex kernelRowsEff = kernelRows + (kernelRows - 1) * (in_stride - 1);
+ const DenseIndex kernelColsEff = kernelCols + (kernelCols - 1) * (in_stride - 1);
+
+ array<IndexPair<TensorIndex>, 1> contract_dims;
+ contract_dims[0] = IndexPair<TensorIndex>(1, 0);
+
+ const TensorIndex InputRows = isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
+ const TensorIndex InputCols = isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
+
+ TensorIndex out_height;
+ TensorIndex out_width;
+ switch (padding_type) {
+ case PADDING_VALID:
+ out_height = numext::ceil((InputRows - kernelRowsEff + 1.f) / static_cast<float>(stride));
+ out_width = numext::ceil((InputCols - kernelColsEff + 1.f) / static_cast<float>(stride));
+ break;
+ case PADDING_SAME:
+ out_height = numext::ceil(InputRows / static_cast<float>(stride));
+ out_width = numext::ceil(InputCols / static_cast<float>(stride));
+ break;
+ default:
+ eigen_assert(false && "unexpected padding");
+ }
+
+ // Molds the output of the patch extraction code into a 2d tensor:
+ // - the first dimension (dims[0]): the patch values to be multiplied with the kernels
+ // - the second dimension (dims[1]): everything else
+ DSizes<TensorIndex, 2> pre_contract_dims;
+ if (isColMajor) {
+ pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols;
+ pre_contract_dims[1] = out_height * out_width;
+ for (int i = 3; i < NumDims; ++i) {
+ pre_contract_dims[1] *= in.dimension(i);
+ }
+ } else {
+ pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols;
+ pre_contract_dims[0] = out_height * out_width;
+ for (int i = 0; i < NumDims - 3; ++i) {
+ pre_contract_dims[0] *= in.dimension(i);
+ }
+ }
+
+ // Molds the output of the contraction into the shape expected by the used
+ // (assuming this is ColMajor):
+ // - 1st dim: kernel filters
+ // - 2nd dim: output height
+ // - 3rd dim: output width
+ // - 4th dim and beyond: everything else including batch size
+ DSizes<TensorIndex, NumDims> post_contract_dims;
+ if (isColMajor) {
+ post_contract_dims[0] = kernelFilters;
+ post_contract_dims[1] = out_height;
+ post_contract_dims[2] = out_width;
+ for (int i = 3; i < NumDims; ++i) {
+ post_contract_dims[i] = in.dimension(i);
+ }
+ } else {
+ post_contract_dims[NumDims - 1] = kernelFilters;
+ post_contract_dims[NumDims - 2] = out_height;
+ post_contract_dims[NumDims - 3] = out_width;
+ for (int i = 0; i < NumDims - 3; ++i) {
+ post_contract_dims[i] = in.dimension(i);
+ }
+ }
+
+ DSizes<TensorIndex, 2> kernel_dims;
+ if (isColMajor) {
+ kernel_dims[0] = kernelFilters;
+ kernel_dims[1] = kernelChannels * kernelRows * kernelCols;
+ } else {
+ kernel_dims[0] = kernelChannels * kernelRows * kernelCols;
+ kernel_dims[1] = kernelFilters;
+ }
+ // TODO(yangke): choose() is defined in TensorContraction.h -- consider
+ // moving it to somewhere more "common".
+ return choose(Cond<internal::traits<Input>::Layout == ColMajor>(),
+ kernel.reshape(kernel_dims).contract(input.extract_image_patches(kernelRows, kernelCols, stride, stride, in_stride, in_stride, padding_type).reshape(pre_contract_dims), contract_dims).reshape(post_contract_dims),
+ input.extract_image_patches(kernelRows, kernelCols, stride, stride, in_stride, in_stride, padding_type).reshape(pre_contract_dims).contract(kernel.reshape(kernel_dims), contract_dims).reshape(post_contract_dims));
+}
+
+} // end namespace Eigen
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_H_
diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions_test.cc b/tensorflow/core/kernels/eigen_spatial_convolutions_test.cc
new file mode 100644
index 0000000000..f20287e73e
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_spatial_convolutions_test.cc
@@ -0,0 +1,1215 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/eigen_spatial_convolutions.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/eigen_cuboid_convolution.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace Eigen {
+
+namespace {
+void EigenApprox(float a, float b) {
+ ASSERT_TRUE(std::abs(a - b) <= std::min(std::abs(a), std::abs(b)) * 1e-3);
+}
+static int ceil_div(int a, int b) { return (a + b - 1) / b; }
+}
+
+TEST(EigenSpatialConvolutionsTest, Simple) {
+ const int input_depth = 7;
+ const int input_rows = 4;
+ const int input_cols = 5;
+ const int output_depth = 10;
+ const int patch_rows = 3;
+ const int patch_cols = 4;
+ const int output_rows = input_rows;
+ const int output_cols = input_cols;
+
+ Tensor<float, 3> input(input_depth, input_rows, input_cols);
+ Tensor<float, 4> kernel(output_depth, input_depth, patch_rows, patch_cols);
+ Tensor<float, 3> result(output_depth, output_rows, output_cols);
+
+ input = input.constant(11.0f) + input.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ result.setRandom();
+
+ result = SpatialConvolution(input, kernel);
+
+ EXPECT_EQ(result.dimension(0), output_depth);
+ EXPECT_EQ(result.dimension(1), output_rows);
+ EXPECT_EQ(result.dimension(2), output_cols);
+
+ for (int od = 0; od < output_depth; ++od) {
+ for (int i = 0; i < output_rows; ++i) {
+ for (int j = 0; j < output_cols; ++j) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int id = 0; id < input_depth; ++id) {
+ if (r - 1 + i >= 0 && c - 1 + j >= 0 && r - 1 + i < output_rows &&
+ c - 1 + j < output_cols) {
+ expected +=
+ input(id, r - 1 + i, c - 1 + j) * kernel(od, id, r, c);
+ }
+ }
+ }
+ }
+ EigenApprox(result(od, i, j), expected);
+ }
+ }
+ }
+}
+
+TEST(EigenSpatialConvolutionsTest, SimpleRowMajor) {
+ const int input_depth = 7;
+ const int input_rows = 4;
+ const int input_cols = 5;
+ const int output_depth = 10;
+ const int patch_rows = 3;
+ const int patch_cols = 4;
+ const int output_rows = input_rows;
+ const int output_cols = input_cols;
+
+ Tensor<float, 3, RowMajor> input(input_cols, input_rows, input_depth);
+ Tensor<float, 4, RowMajor> kernel(patch_cols, patch_rows, input_depth,
+ output_depth);
+ Tensor<float, 3, RowMajor> result(output_cols, output_rows, output_depth);
+ input = input.constant(11.0f) + input.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ result.setRandom();
+
+ result = SpatialConvolution(input, kernel);
+
+ EXPECT_EQ(result.dimension(0), output_cols);
+ EXPECT_EQ(result.dimension(1), output_rows);
+ EXPECT_EQ(result.dimension(2), output_depth);
+
+ for (int od = 0; od < output_depth; ++od) {
+ for (int i = 0; i < output_rows; ++i) {
+ for (int j = 0; j < output_cols; ++j) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int id = 0; id < input_depth; ++id) {
+ if (r - 1 + i >= 0 && c - 1 + j >= 0 && r - 1 + i < output_rows &&
+ c - 1 + j < output_cols) {
+ expected +=
+ input(c - 1 + j, r - 1 + i, id) * kernel(c, r, id, od);
+ }
+ }
+ }
+ }
+ EigenApprox(result(j, i, od), expected);
+ }
+ }
+ }
+}
+
+TEST(EigenSpatialConvolutionsTest, BatchedSpatialConvolution) {
+ Tensor<float, 4> input(10, 5, 5, 13);
+ Tensor<float, 4> kernel(7, 10, 3, 3);
+ Tensor<float, 4> result(7, 5, 5, 13);
+ input = input.constant(11.0f) + input.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ result.setRandom();
+
+ result = SpatialConvolution(input, kernel);
+
+ EXPECT_EQ(result.dimension(0), 7);
+ EXPECT_EQ(result.dimension(1), 5);
+ EXPECT_EQ(result.dimension(2), 5);
+
+ for (int b = 0; b < 13; ++b) {
+ for (int od = 0; od < 7; ++od) {
+ for (int i = 0; i < 5; ++i) {
+ for (int j = 0; j < 5; ++j) {
+ float expected = 0.0f;
+ for (int c = 0; c < 3; ++c) {
+ for (int r = 0; r < 3; ++r) {
+ for (int id = 0; id < 10; ++id) {
+ if (r - 1 + i >= 0 && c - 1 + j >= 0 && r - 1 + i < 5 &&
+ c - 1 + j < 5) {
+ expected +=
+ input(id, r - 1 + i, c - 1 + j, b) * kernel(od, id, r, c);
+ }
+ }
+ }
+ }
+ EigenApprox(result(od, i, j, b), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenSpatialConvolutionsTest, BatchedSpatialConvolutionRowMajor) {
+ Tensor<float, 4, RowMajor> input(13, 5, 5, 10);
+ Tensor<float, 4, RowMajor> kernel(3, 3, 10, 7);
+ Tensor<float, 4, RowMajor> result(13, 5, 5, 7);
+ input = input.constant(11.0f) + input.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ result.setRandom();
+
+ result = SpatialConvolution(input, kernel);
+
+ EXPECT_EQ(result.dimension(1), 5);
+ EXPECT_EQ(result.dimension(2), 5);
+ EXPECT_EQ(result.dimension(3), 7);
+
+ for (int b = 0; b < 13; ++b) {
+ for (int od = 0; od < 7; ++od) {
+ for (int i = 0; i < 5; ++i) {
+ for (int j = 0; j < 5; ++j) {
+ float expected = 0.0f;
+ for (int c = 0; c < 3; ++c) {
+ for (int r = 0; r < 3; ++r) {
+ for (int id = 0; id < 10; ++id) {
+ if (r - 1 + i >= 0 && c - 1 + j >= 0 && r - 1 + i < 5 &&
+ c - 1 + j < 5) {
+ expected +=
+ input(b, c - 1 + j, r - 1 + i, id) * kernel(c, r, id, od);
+ }
+ }
+ }
+ }
+ EigenApprox(result(b, j, i, od), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenSpatialConvolutionsTest, ValidSpatialConvolution) {
+ const int input_depth = 10;
+ const int input_rows = 5;
+ const int input_cols = 5;
+ const int num_batches = 13;
+ const int output_depth = 7;
+ const int patch_rows = 4;
+ const int patch_cols = 4;
+ const int output_rows = input_rows - patch_rows + 1;
+ const int output_cols = input_cols - patch_cols + 1;
+
+ Tensor<float, 4> input(input_depth, input_rows, input_cols, num_batches);
+ Tensor<float, 4> kernel(output_depth, input_depth, patch_rows, patch_cols);
+ Tensor<float, 4> result(output_depth, output_rows, output_cols, num_batches);
+ input = input.constant(11.0f) + input.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ result.setRandom();
+
+ // Apply a spatial convolution using a 4x4 kernel, valid padding, and a stride
+ // of 1.
+ const int stride = 1;
+ result = SpatialConvolution(input, kernel, stride, PADDING_VALID);
+
+ EXPECT_EQ(result.dimension(0), output_depth);
+ EXPECT_EQ(result.dimension(1), output_rows);
+ EXPECT_EQ(result.dimension(2), output_cols);
+ EXPECT_EQ(result.dimension(3), num_batches);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int od = 0; od < output_depth; ++od) {
+ for (int i = 0; i < output_rows; ++i) {
+ for (int j = 0; j < output_cols; ++j) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int id = 0; id < input_depth; ++id) {
+ expected += input(id, r + i, c + j, b) * kernel(od, id, r, c);
+ }
+ }
+ }
+ if (result(od, i, j, b) != expected) {
+ std::cout << "at od=" << od << " b=" << b << " i=" << i
+ << " j=" << j << " " << result(od, i, j, b) << " vs "
+ << expected << std::endl;
+ }
+ EigenApprox(result(od, i, j, b), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenSpatialConvolutionsTest, ValidSpatialConvolutionRowMajor) {
+ const int input_depth = 10;
+ const int input_rows = 5;
+ const int input_cols = 5;
+ const int num_batches = 13;
+ const int output_depth = 7;
+ const int patch_rows = 4;
+ const int patch_cols = 4;
+ const int output_rows = input_rows - patch_rows + 1;
+ const int output_cols = input_cols - patch_cols + 1;
+
+ Tensor<float, 4, RowMajor> input(num_batches, input_cols, input_rows,
+ input_depth);
+ Tensor<float, 4, RowMajor> kernel(patch_cols, patch_rows, input_depth,
+ output_depth);
+ Tensor<float, 4, RowMajor> result(num_batches, output_cols, output_rows,
+ output_depth);
+
+ input = input.constant(11.0f) + input.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ result.setRandom();
+
+ // Apply a spatial convolution using a 4x4 kernel, valid padding, and a stride
+ // of 1.
+ const int stride = 1;
+ result = SpatialConvolution(input, kernel, stride, PADDING_VALID);
+
+ EXPECT_EQ(result.dimension(0), num_batches);
+ EXPECT_EQ(result.dimension(1), output_cols);
+ EXPECT_EQ(result.dimension(2), output_rows);
+ EXPECT_EQ(result.dimension(3), output_depth);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int od = 0; od < output_depth; ++od) {
+ for (int i = 0; i < output_rows; ++i) {
+ for (int j = 0; j < output_cols; ++j) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_rows; ++c) {
+ for (int r = 0; r < patch_cols; ++r) {
+ for (int id = 0; id < input_depth; ++id) {
+ expected += input(b, c + j, r + i, id) * kernel(c, r, id, od);
+ }
+ }
+ }
+ if (result(b, j, i, od) != expected) {
+ std::cout << "at od=" << od << " b=" << b << " i=" << i
+ << " j=" << j << " " << result(b, j, i, od) << " vs "
+ << expected << std::endl;
+ }
+ EigenApprox(result(b, j, i, od), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenSpatialConvolutionsTest, StridedSpatialConvolution) {
+ const int input_depth = 10;
+ const int input_rows = 5;
+ const int input_cols = 5;
+ const int num_batches = 13;
+ const int output_depth = 7;
+ const int patch_rows = 3;
+ const int patch_cols = 3;
+ const int output_rows = 2;
+ const int output_cols = 2;
+
+ Tensor<float, 4> input(input_depth, input_rows, input_cols, num_batches);
+ Tensor<float, 4> kernel(output_depth, input_depth, patch_rows, patch_cols);
+ Tensor<float, 4> result(output_depth, output_rows, output_cols, num_batches);
+ input = input.constant(11.0f) + input.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ result.setRandom();
+
+ // Apply a spatial convolution using a 3x3 kernel, valid padding, and a stride
+ // of 2.
+ int stride = 2;
+ result = SpatialConvolution(input, kernel, stride, PADDING_VALID);
+
+ EXPECT_EQ(result.dimension(0), output_depth);
+ EXPECT_EQ(result.dimension(1), output_rows);
+ EXPECT_EQ(result.dimension(2), output_cols);
+ EXPECT_EQ(result.dimension(3), num_batches);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int od = 0; od < output_depth; ++od) {
+ for (int i = 0; i < output_rows; ++i) {
+ for (int j = 0; j < output_cols; ++j) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int id = 0; id < input_depth; ++id) {
+ expected += input(id, r + stride * i, c + stride * j, b) *
+ kernel(od, id, r, c);
+ }
+ }
+ }
+ EigenApprox(result(od, i, j, b), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenSpatialConvolutionsTest, StridedSpatialConvolutionRowMajor) {
+ const int input_depth = 10;
+ const int input_rows = 5;
+ const int input_cols = 5;
+ const int num_batches = 13;
+ const int output_depth = 7;
+ const int patch_rows = 3;
+ const int patch_cols = 3;
+ const int output_rows = 2;
+ const int output_cols = 2;
+
+ Tensor<float, 4, RowMajor> input(num_batches, input_cols, input_rows,
+ input_depth);
+ Tensor<float, 4, RowMajor> kernel(patch_cols, patch_rows, input_depth,
+ output_depth);
+ Tensor<float, 4, RowMajor> result(num_batches, output_cols, output_rows,
+ output_depth);
+ input = input.constant(11.0f) + input.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ result.setRandom();
+
+ // Apply a spatial convolution using a 3x3 kernel, valid padding, and a stride
+ // of 2.
+ int stride = 2;
+ result = SpatialConvolution(input, kernel, stride, PADDING_VALID);
+
+ EXPECT_EQ(result.dimension(0), num_batches);
+ EXPECT_EQ(result.dimension(1), output_cols);
+ EXPECT_EQ(result.dimension(2), output_rows);
+ EXPECT_EQ(result.dimension(3), output_depth);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int od = 0; od < output_depth; ++od) {
+ for (int i = 0; i < output_rows; ++i) {
+ for (int j = 0; j < output_cols; ++j) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int id = 0; id < input_depth; ++id) {
+ expected += input(b, c + stride * j, r + stride * i, id) *
+ kernel(c, r, id, od);
+ }
+ }
+ }
+ EigenApprox(result(b, j, i, od), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenSpatialConvolutionsTest, AtrousSpatial) {
+ const int input_depth = 10;
+ const int input_rows = 7;
+ const int input_cols = 7;
+ const int num_batches = 13;
+ const int output_depth = 7;
+ const int patch_rows = 3;
+ const int patch_cols = 3;
+ const int output_rows = 3;
+ const int output_cols = 3;
+
+ Tensor<float, 4> input(input_depth, input_rows, input_cols, num_batches);
+ Tensor<float, 4> kernel(output_depth, input_depth, patch_rows, patch_cols);
+ Tensor<float, 4> result(output_depth, output_rows, output_cols, num_batches);
+ input = input.constant(11.0f) + input.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ result.setRandom();
+
+ // Apply a spatial convolution using a 3x3 kernel, valid padding
+ // output (standard) stride 1, and input (atrous) stride of 2.
+ int stride = 1;
+ int in_stride = 2;
+ result = SpatialConvolution(input, kernel, stride, PADDING_VALID, in_stride);
+
+ EXPECT_EQ(result.dimension(0), output_depth);
+ EXPECT_EQ(result.dimension(1), output_rows);
+ EXPECT_EQ(result.dimension(2), output_cols);
+ EXPECT_EQ(result.dimension(3), num_batches);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int od = 0; od < output_depth; ++od) {
+ for (int i = 0; i < output_rows; ++i) {
+ for (int j = 0; j < output_cols; ++j) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int id = 0; id < input_depth; ++id) {
+ expected += input(id, in_stride * r + stride * i,
+ in_stride * c + stride * j, b) *
+ kernel(od, id, r, c);
+ }
+ }
+ }
+ EigenApprox(result(od, i, j, b), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenSpatialConvolutionsTest, AtrousSpatialRowMajor) {
+ const int input_depth = 10;
+ const int input_rows = 7;
+ const int input_cols = 7;
+ const int num_batches = 13;
+ const int output_depth = 7;
+ const int patch_rows = 3;
+ const int patch_cols = 3;
+ const int output_rows = 3;
+ const int output_cols = 3;
+
+ Tensor<float, 4, RowMajor> input(num_batches, input_cols, input_rows,
+ input_depth);
+ Tensor<float, 4, RowMajor> kernel(patch_cols, patch_rows, input_depth,
+ output_depth);
+ Tensor<float, 4, RowMajor> result(num_batches, output_cols, output_rows,
+ output_depth);
+ input = input.constant(11.0f) + input.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ result.setRandom();
+
+ // Apply a spatial convolution using a 3x3 kernel, valid padding
+ // output (standard) stride 1, and input (atrous) stride of 2.
+ int stride = 1;
+ int in_stride = 2;
+ result = SpatialConvolution(input, kernel, stride, PADDING_VALID, in_stride);
+
+ EXPECT_EQ(result.dimension(0), num_batches);
+ EXPECT_EQ(result.dimension(1), output_cols);
+ EXPECT_EQ(result.dimension(2), output_rows);
+ EXPECT_EQ(result.dimension(3), output_depth);
+
+ for (int b = 0; b < num_batches; ++b) {
+ for (int od = 0; od < output_depth; ++od) {
+ for (int i = 0; i < output_rows; ++i) {
+ for (int j = 0; j < output_cols; ++j) {
+ float expected = 0.0f;
+ for (int c = 0; c < patch_cols; ++c) {
+ for (int r = 0; r < patch_rows; ++r) {
+ for (int id = 0; id < input_depth; ++id) {
+ expected += input(b, in_stride * c + stride * j,
+ in_stride * r + stride * i, id) *
+ kernel(c, r, id, od);
+ }
+ }
+ }
+ EigenApprox(result(b, j, i, od), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenSpatialConvolutionsTest, Cuboid) {
+ const int in_channels = 10;
+ const int in_depth = 5;
+ const int in_rows = 8;
+ const int in_cols = 7;
+
+ const int kern_filters = 7;
+ const int kern_depth = 3;
+ const int kern_width = 4;
+ const int kern_height = 4;
+
+ const int out_depth = in_depth;
+ const int out_height = in_rows;
+ const int out_width = in_cols;
+
+ Tensor<float, 4> input(in_channels, in_depth, in_rows, in_cols);
+ Tensor<float, 5> kernel(kern_filters, in_channels, kern_depth, kern_height,
+ kern_width);
+ Tensor<float, 4> result(kern_filters, out_depth, out_height, out_width);
+ input = input.constant(11.0f) + input.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ result.setRandom();
+
+ result = CuboidConvolution(input, kernel);
+
+ EXPECT_EQ(result.dimension(0), kern_filters);
+ EXPECT_EQ(result.dimension(1), out_depth);
+ EXPECT_EQ(result.dimension(2), out_height);
+ EXPECT_EQ(result.dimension(3), out_width);
+
+ const int off_p = kern_depth / 2;
+ const int off_r = kern_height / 2;
+ const int off_c = kern_width / 2;
+
+ for (int od = 0; od < kern_filters; ++od) {
+ for (int i = 0; i < out_depth; ++i) {
+ for (int j = 0; j < out_height; ++j) {
+ for (int k = 0; k < out_width; ++k) {
+ float expected = 0.0f;
+ for (int c = 0; c < kern_width; ++c) {
+ for (int r = 0; r < kern_height; ++r) {
+ for (int p = 0; p < kern_depth; ++p) {
+ for (int id = 0; id < in_channels; ++id) {
+ if (p - off_p + i >= 0 && r - off_r + j >= 0 &&
+ c - off_c + k >= 0 && p - off_p + i < in_depth &&
+ r - off_r + j < in_rows && c - off_c + k < in_cols) {
+ expected +=
+ input(id, p - off_p + i, r - off_r + j, c - off_c + k) *
+ kernel(od, id, p, r, c);
+ }
+ }
+ }
+ }
+ }
+ EigenApprox(result(od, i, j, k), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenSpatialConvolutionsTest, CuboidRowMajor) {
+ const int in_channels = 10;
+ const int in_depth = 5;
+ const int in_rows = 8;
+ const int in_cols = 7;
+
+ const int kern_filters = 7;
+ const int kern_depth = 3;
+ const int kern_width = 4;
+ const int kern_height = 4;
+
+ const int out_depth = in_depth;
+ const int out_height = in_rows;
+ const int out_width = in_cols;
+
+ Tensor<float, 4, RowMajor> input(in_cols, in_rows, in_depth, in_channels);
+ Tensor<float, 5, RowMajor> kernel(kern_width, kern_height, kern_depth,
+ in_channels, kern_filters);
+ Tensor<float, 4, RowMajor> result(out_width, out_height, out_depth,
+ kern_filters);
+ input = input.constant(11.0f) + input.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ result.setRandom();
+
+ result = CuboidConvolution(input, kernel);
+
+ EXPECT_EQ(result.dimension(3), kern_filters);
+ EXPECT_EQ(result.dimension(2), out_depth);
+ EXPECT_EQ(result.dimension(1), out_height);
+ EXPECT_EQ(result.dimension(0), out_width);
+
+ const int off_p = kern_depth / 2;
+ const int off_r = kern_height / 2;
+ const int off_c = kern_width / 2;
+
+ for (int od = 0; od < kern_filters; ++od) {
+ for (int i = 0; i < out_depth; ++i) {
+ for (int j = 0; j < out_height; ++j) {
+ for (int k = 0; k < out_width; ++k) {
+ float expected = 0.0f;
+ for (int c = 0; c < kern_width; ++c) {
+ for (int r = 0; r < kern_height; ++r) {
+ for (int p = 0; p < kern_depth; ++p) {
+ for (int id = 0; id < in_channels; ++id) {
+ if (p - off_p + i >= 0 && r - off_r + j >= 0 &&
+ c - off_c + k >= 0 && p - off_p + i < in_depth &&
+ r - off_r + j < in_rows && c - off_c + k < in_cols) {
+ expected +=
+ input(c - off_c + k, r - off_r + j, p - off_p + i, id) *
+ kernel(c, r, p, id, od);
+ }
+ }
+ }
+ }
+ }
+ EigenApprox(result(k, j, i, od), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenSpatialConvolutionsTest, ValidCuboid) {
+ const int in_channels = 10;
+ const int in_depth = 5;
+ const int in_rows = 5;
+ const int in_cols = 5;
+
+ const int kern_filters = 7;
+ const int kern_depth = 3;
+ const int kern_width = 3;
+ const int kern_height = 3;
+
+ const int out_depth = 3;
+ const int out_height = 3;
+ const int out_width = 3;
+
+ Tensor<float, 4> input(in_channels, in_depth, in_rows, in_cols);
+ Tensor<float, 5> kernel(kern_filters, in_channels, kern_depth, kern_height,
+ kern_width);
+ Tensor<float, 4> result(kern_filters, out_depth, out_height, out_width);
+ input = input.constant(11.0f) + input.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ result.setRandom();
+
+ result = CuboidConvolution(input, kernel, 1, 1, 1, PADDING_VALID);
+
+ EXPECT_EQ(result.dimension(0), kern_filters);
+ EXPECT_EQ(result.dimension(1), out_depth);
+ EXPECT_EQ(result.dimension(2), out_height);
+ EXPECT_EQ(result.dimension(3), out_width);
+
+ for (int od = 0; od < kern_filters; ++od) {
+ for (int i = 0; i < out_depth; ++i) {
+ for (int j = 0; j < out_height; ++j) {
+ for (int k = 0; k < out_width; ++k) {
+ float expected = 0.0f;
+ for (int c = 0; c < kern_width; ++c) {
+ for (int r = 0; r < kern_height; ++r) {
+ for (int p = 0; p < kern_depth; ++p) {
+ for (int id = 0; id < in_channels; ++id) {
+ expected +=
+ input(id, p + i, r + j, c + k) * kernel(od, id, p, r, c);
+ }
+ }
+ }
+ }
+ EigenApprox(result(od, i, j, k), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenSpatialConvolutionsTest, ValidCuboidRowMajor) {
+ const int in_channels = 10;
+ const int in_depth = 5;
+ const int in_rows = 5;
+ const int in_cols = 5;
+
+ const int kern_filters = 7;
+ const int kern_depth = 3;
+ const int kern_width = 3;
+ const int kern_height = 3;
+
+ const int out_depth = 3;
+ const int out_height = 3;
+ const int out_width = 3;
+
+ Tensor<float, 4, RowMajor> input(in_cols, in_rows, in_depth, in_channels);
+ Tensor<float, 5, RowMajor> kernel(kern_width, kern_height, kern_depth,
+ in_channels, kern_filters);
+ Tensor<float, 4, RowMajor> result(out_width, out_height, out_depth,
+ kern_filters);
+ input = input.constant(11.0f) + input.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ result.setRandom();
+
+ result = CuboidConvolution(input, kernel, 1, 1, 1, PADDING_VALID);
+
+ EXPECT_EQ(result.dimension(3), kern_filters);
+ EXPECT_EQ(result.dimension(2), out_depth);
+ EXPECT_EQ(result.dimension(1), out_height);
+ EXPECT_EQ(result.dimension(0), out_width);
+
+ for (int od = 0; od < kern_filters; ++od) {
+ for (int i = 0; i < out_depth; ++i) {
+ for (int j = 0; j < out_height; ++j) {
+ for (int k = 0; k < out_width; ++k) {
+ float expected = 0.0f;
+ for (int c = 0; c < kern_width; ++c) {
+ for (int r = 0; r < kern_height; ++r) {
+ for (int p = 0; p < kern_depth; ++p) {
+ for (int id = 0; id < in_channels; ++id) {
+ expected +=
+ input(c + k, r + j, p + i, id) * kernel(c, r, p, id, od);
+ }
+ }
+ }
+ }
+ EigenApprox(result(k, j, i, od), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenSpatialConvolutionsTest, BatchedCuboid) {
+ const int batches = 2;
+ const int in_channels = 10;
+ const int in_depth = 5;
+ const int in_rows = 8;
+ const int in_cols = 7;
+
+ const int kern_filters = 7;
+ const int kern_depth = 3;
+ const int kern_width = 4;
+ const int kern_height = 4;
+
+ const int out_depth = in_depth;
+ const int out_height = in_rows;
+ const int out_width = in_cols;
+
+ Tensor<float, 5> input(in_channels, in_depth, in_rows, in_cols, batches);
+ Tensor<float, 5> kernel(kern_filters, in_channels, kern_depth, kern_height,
+ kern_width);
+ Tensor<float, 5> result(kern_filters, out_depth, out_height, out_width,
+ batches);
+ input = input.constant(11.0f) + input.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ result.setRandom();
+
+ result = CuboidConvolution(input, kernel);
+
+ EXPECT_EQ(result.dimension(0), kern_filters);
+ EXPECT_EQ(result.dimension(1), out_depth);
+ EXPECT_EQ(result.dimension(2), out_height);
+ EXPECT_EQ(result.dimension(3), out_width);
+ EXPECT_EQ(result.dimension(4), batches);
+
+ const int off_p = kern_depth / 2;
+ const int off_r = kern_height / 2;
+ const int off_c = kern_width / 2;
+
+ for (int b = 0; b < batches; b++) {
+ for (int od = 0; od < kern_filters; ++od) {
+ for (int i = 0; i < out_depth; ++i) {
+ for (int j = 0; j < out_height; ++j) {
+ for (int k = 0; k < out_width; ++k) {
+ float expected = 0.0f;
+ for (int c = 0; c < kern_width; ++c) {
+ for (int r = 0; r < kern_height; ++r) {
+ for (int p = 0; p < kern_depth; ++p) {
+ for (int id = 0; id < in_channels; ++id) {
+ if (p - off_p + i >= 0 && r - off_r + j >= 0 &&
+ c - off_c + k >= 0 && p - off_p + i < in_depth &&
+ r - off_r + j < in_rows && c - off_c + k < in_cols) {
+ expected += input(id, p - off_p + i, r - off_r + j,
+ c - off_c + k, b) *
+ kernel(od, id, p, r, c);
+ }
+ }
+ }
+ }
+ }
+ EigenApprox(result(od, i, j, k, b), expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenSpatialConvolutionsTest, BatchedCuboidRowMajor) {
+ const int batches = 2;
+ const int in_channels = 10;
+ const int in_depth = 5;
+ const int in_rows = 8;
+ const int in_cols = 7;
+
+ const int kern_filters = 7;
+ const int kern_depth = 3;
+ const int kern_width = 4;
+ const int kern_height = 4;
+
+ const int out_depth = in_depth;
+ const int out_height = in_rows;
+ const int out_width = in_cols;
+
+ Tensor<float, 5, RowMajor> input(batches, in_cols, in_rows, in_depth,
+ in_channels);
+ Tensor<float, 5, RowMajor> kernel(kern_width, kern_height, kern_depth,
+ in_channels, kern_filters);
+ Tensor<float, 5, RowMajor> result(batches, out_width, out_height, out_depth,
+ kern_filters);
+ input = input.constant(11.0f) + input.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ result.setRandom();
+
+ result = CuboidConvolution(input, kernel);
+
+ EXPECT_EQ(result.dimension(4), kern_filters);
+ EXPECT_EQ(result.dimension(3), out_depth);
+ EXPECT_EQ(result.dimension(2), out_height);
+ EXPECT_EQ(result.dimension(1), out_width);
+ EXPECT_EQ(result.dimension(0), batches);
+
+ const int off_p = kern_depth / 2;
+ const int off_r = kern_height / 2;
+ const int off_c = kern_width / 2;
+
+ for (int b = 0; b < batches; b++) {
+ for (int od = 0; od < kern_filters; ++od) {
+ for (int i = 0; i < out_depth; ++i) {
+ for (int j = 0; j < out_height; ++j) {
+ for (int k = 0; k < out_width; ++k) {
+ float expected = 0.0f;
+ for (int c = 0; c < kern_width; ++c) {
+ for (int r = 0; r < kern_height; ++r) {
+ for (int p = 0; p < kern_depth; ++p) {
+ for (int id = 0; id < in_channels; ++id) {
+ if (p - off_p + i >= 0 && r - off_r + j >= 0 &&
+ c - off_c + k >= 0 && p - off_p + i < in_depth &&
+ r - off_r + j < in_rows && c - off_c + k < in_cols) {
+ expected += input(b, c - off_c + k, r - off_r + j,
+ p - off_p + i, id) *
+ kernel(c, r, p, id, od);
+ }
+ }
+ }
+ }
+ }
+ EigenApprox(result(b, k, j, i, od), expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenSpatialConvolutionsTest, StridedValidCuboid) {
+ const int in_channels = 10;
+ const int in_depth = 8;
+ const int in_rows = 7;
+ const int in_cols = 5;
+
+ const int kern_filters = 7;
+ const int kern_depth = 3;
+ const int kern_width = 3;
+ const int kern_height = 3;
+
+ const int out_depth = 3;
+ const int out_height = 3;
+ const int out_width = 2;
+
+ Tensor<float, 4> input(in_channels, in_depth, in_rows, in_cols);
+ Tensor<float, 5> kernel(kern_filters, in_channels, kern_depth, kern_height,
+ kern_width);
+ Tensor<float, 4> result(kern_filters, out_depth, out_height, out_width);
+ input = input.constant(11.0f) + input.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ result.setRandom();
+
+ const int stride = 2;
+ result =
+ CuboidConvolution(input, kernel, stride, stride, stride, PADDING_VALID);
+
+ EXPECT_EQ(result.dimension(0), kern_filters);
+ EXPECT_EQ(result.dimension(1), out_depth);
+ EXPECT_EQ(result.dimension(2), out_height);
+ EXPECT_EQ(result.dimension(3), out_width);
+
+ for (int od = 0; od < kern_filters; ++od) {
+ for (int i = 0; i < out_depth; ++i) {
+ for (int j = 0; j < out_height; ++j) {
+ for (int k = 0; k < out_width; ++k) {
+ float expected = 0.0f;
+ for (int c = 0; c < kern_width; ++c) {
+ for (int r = 0; r < kern_height; ++r) {
+ for (int p = 0; p < kern_depth; ++p) {
+ for (int id = 0; id < in_channels; ++id) {
+ expected += input(id, p + stride * i, r + stride * j,
+ c + stride * k) *
+ kernel(od, id, p, r, c);
+ }
+ }
+ }
+ }
+ EigenApprox(result(od, i, j, k), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenSpatialConvolutionsTest, StridedValidCuboidRowMajor) {
+ const int in_channels = 10;
+ const int in_depth = 8;
+ const int in_rows = 7;
+ const int in_cols = 5;
+
+ const int kern_filters = 7;
+ const int kern_depth = 3;
+ const int kern_width = 3;
+ const int kern_height = 3;
+
+ const int out_depth = 3;
+ const int out_height = 3;
+ const int out_width = 2;
+
+ Tensor<float, 4, RowMajor> input(in_cols, in_rows, in_depth, in_channels);
+ Tensor<float, 5, RowMajor> kernel(kern_width, kern_height, kern_depth,
+ in_channels, kern_filters);
+ Tensor<float, 4, RowMajor> result(out_width, out_height, out_depth,
+ kern_filters);
+ input = input.constant(11.0f) + input.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ result.setRandom();
+
+ const int stride = 2;
+ result =
+ CuboidConvolution(input, kernel, stride, stride, stride, PADDING_VALID);
+
+ EXPECT_EQ(result.dimension(3), kern_filters);
+ EXPECT_EQ(result.dimension(2), out_depth);
+ EXPECT_EQ(result.dimension(1), out_height);
+ EXPECT_EQ(result.dimension(0), out_width);
+
+ for (int od = 0; od < kern_filters; ++od) {
+ for (int i = 0; i < out_depth; ++i) {
+ for (int j = 0; j < out_height; ++j) {
+ for (int k = 0; k < out_width; ++k) {
+ float expected = 0.0f;
+ for (int c = 0; c < kern_width; ++c) {
+ for (int r = 0; r < kern_height; ++r) {
+ for (int p = 0; p < kern_depth; ++p) {
+ for (int id = 0; id < in_channels; ++id) {
+ expected += input(c + stride * k, r + stride * j,
+ p + stride * i, id) *
+ kernel(c, r, p, id, od);
+ }
+ }
+ }
+ }
+ EigenApprox(result(k, j, i, od), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenSpatialConvolutionsTest, StridedSameCuboid) {
+ const int in_channels = 10;
+ const int in_depth = 8;
+ const int in_rows = 7;
+ const int in_cols = 5;
+
+ const int kern_filters = 7;
+ const int kern_depth = 3;
+ const int kern_width = 3;
+ const int kern_height = 3;
+
+ const int stride = 2;
+ const int out_depth = ceil_div(in_depth, stride);
+ const int out_height = ceil_div(in_rows, stride);
+ const int out_width = ceil_div(in_cols, stride);
+
+ Tensor<float, 4> input(in_channels, in_depth, in_rows, in_cols);
+ Tensor<float, 5> kernel(kern_filters, in_channels, kern_depth, kern_height,
+ kern_width);
+ Tensor<float, 4> result(kern_filters, out_depth, out_height, out_width);
+ input = input.constant(11.0f) + input.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ result.setRandom();
+
+ result =
+ CuboidConvolution(input, kernel, stride, stride, stride, PADDING_SAME);
+
+ EXPECT_EQ(result.dimension(0), kern_filters);
+ EXPECT_EQ(result.dimension(1), out_depth);
+ EXPECT_EQ(result.dimension(2), out_height);
+ EXPECT_EQ(result.dimension(3), out_width);
+
+ const int pad_p = out_depth * stride - in_depth + kern_depth - 1;
+ const int pad_r = out_height * stride - in_rows + kern_height - 1;
+ const int pad_c = out_width * stride - in_cols + kern_width - 1;
+
+ // Number of pixels the input is extended with at the lower end in every
+ // dimension.
+ const int dp = pad_p - pad_p / 2;
+ const int dr = pad_r - pad_r / 2;
+ const int dc = pad_c - pad_c / 2;
+
+ for (int od = 0; od < kern_filters; ++od) {
+ for (int i = 0; i < out_depth; ++i) {
+ for (int j = 0; j < out_height; ++j) {
+ for (int k = 0; k < out_width; ++k) {
+ float expected = 0.0f;
+ for (int c = 0; c < kern_width; ++c) {
+ for (int r = 0; r < kern_height; ++r) {
+ for (int p = 0; p < kern_depth; ++p) {
+ for (int id = 0; id < in_channels; ++id) {
+ const int in_p = p - dp + i * stride;
+ const int in_r = r - dr + j * stride;
+ const int in_c = c - dc + k * stride;
+ if (in_p >= 0 && in_r >= 0 && in_c >= 0 && in_p < in_depth &&
+ in_r < in_rows && in_c < in_cols) {
+ expected +=
+ input(id, in_p, in_r, in_c) * kernel(od, id, p, r, c);
+ }
+ }
+ }
+ }
+ }
+ EigenApprox(result(od, i, j, k), expected);
+ }
+ }
+ }
+ }
+}
+
+TEST(EigenSpatialConvolutionsTest, StridedSameCuboidRowMajor) {
+ const int in_channels = 10;
+ const int in_depth = 8;
+ const int in_rows = 7;
+ const int in_cols = 5;
+
+ const int kern_filters = 7;
+ const int kern_depth = 3;
+ const int kern_width = 3;
+ const int kern_height = 3;
+
+ const int stride = 2;
+ const int out_depth = ceil_div(in_depth, stride);
+ const int out_height = ceil_div(in_rows, stride);
+ const int out_width = ceil_div(in_cols, stride);
+
+ Tensor<float, 4, RowMajor> input(in_cols, in_rows, in_depth, in_channels);
+ Tensor<float, 5, RowMajor> kernel(kern_width, kern_height, kern_depth,
+ in_channels, kern_filters);
+ Tensor<float, 4, RowMajor> result(out_width, out_height, out_depth,
+ kern_filters);
+ input = input.constant(11.0f) + input.random();
+ kernel = kernel.constant(2.0f) + kernel.random();
+ result.setRandom();
+
+ result =
+ CuboidConvolution(input, kernel, stride, stride, stride, PADDING_SAME);
+
+ EXPECT_EQ(result.dimension(3), kern_filters);
+ EXPECT_EQ(result.dimension(2), out_depth);
+ EXPECT_EQ(result.dimension(1), out_height);
+ EXPECT_EQ(result.dimension(0), out_width);
+
+ const int pad_p = out_depth * stride - in_depth + kern_depth - 1;
+ const int pad_r = out_height * stride - in_rows + kern_height - 1;
+ const int pad_c = out_width * stride - in_cols + kern_width - 1;
+
+ // Number of pixels the input is extended with at the lower end in every
+ // dimension.
+ const int dp = pad_p - pad_p / 2;
+ const int dr = pad_r - pad_r / 2;
+ const int dc = pad_c - pad_c / 2;
+
+ for (int od = 0; od < kern_filters; ++od) {
+ for (int i = 0; i < out_depth; ++i) {
+ for (int j = 0; j < out_height; ++j) {
+ for (int k = 0; k < out_width; ++k) {
+ float expected = 0.0f;
+ for (int c = 0; c < kern_width; ++c) {
+ for (int r = 0; r < kern_height; ++r) {
+ for (int p = 0; p < kern_depth; ++p) {
+ for (int id = 0; id < in_channels; ++id) {
+ const int in_p = p - dp + i * stride;
+ const int in_r = r - dr + j * stride;
+ const int in_c = c - dc + k * stride;
+ if (in_p >= 0 && in_r >= 0 && in_c >= 0 && in_p < in_depth &&
+ in_r < in_rows && in_c < in_cols) {
+ expected +=
+ input(in_c, in_r, in_p, id) * kernel(c, r, p, id, od);
+ }
+ }
+ }
+ }
+ }
+ EigenApprox(result(k, j, i, od), expected);
+ }
+ }
+ }
+ }
+}
+
+// A test case discovered when testing backward spatial convolution where the
+// special tensor contraction mapper for spatial convolution contains a bug.
+TEST(EigenSpatialConvolutionsTest, SpatialConvContractionMapper) {
+ // We have a 3x4 input image with 2x2 patch and stride of 2.
+ // The output has size 1x2.
+ typedef Tensor<float, 1>::DimensionPair DimPair;
+ Tensor<float, 4> out(1, 1, 2, 1);
+ Tensor<float, 4> kern(1, 1, 2, 2);
+ for (int i = 0; i < kern.size(); ++i) {
+ kern.coeffRef(i) = static_cast<float>(i) + 1;
+ }
+ for (int i = 0; i < out.size(); ++i) {
+ out.coeffRef(i) = static_cast<float>(i) + 1;
+ }
+
+ DSizes<ptrdiff_t, 4> strides;
+ strides[0] = 1;
+ strides[1] = 2;
+ strides[2] = 2;
+ strides[3] = 1;
+
+ array<std::pair<ptrdiff_t, ptrdiff_t>, 4> paddings;
+ paddings[0] = std::make_pair(0, 0);
+ paddings[1] = std::make_pair(1, 2);
+ paddings[2] = std::make_pair(1, 1);
+ paddings[3] = std::make_pair(0, 0);
+
+ DSizes<ptrdiff_t, 3> out_dim;
+ out_dim[0] = 1;
+ out_dim[1] = 4;
+ out_dim[2] = 12;
+
+ array<bool, 4> kernel_reverse;
+ kernel_reverse[0] = false;
+ kernel_reverse[1] = false;
+ kernel_reverse[2] = true;
+ kernel_reverse[3] = true;
+
+ DSizes<ptrdiff_t, 3> k_dims;
+ k_dims[0] = 1;
+ k_dims[1] = 1;
+ k_dims[2] = 4;
+
+ array<DimPair, 2> contract_dims;
+ contract_dims[0] = DimPair(0, 0);
+ contract_dims[1] = DimPair(2, 1);
+
+ DSizes<ptrdiff_t, 4> in_dim;
+ in_dim[0] = 1;
+ in_dim[1] = 3;
+ in_dim[2] = 4;
+ in_dim[3] = 1;
+
+ DSizes<ptrdiff_t, 2> in_dbg_dim;
+ in_dbg_dim[0] = 3;
+ in_dbg_dim[1] = 4;
+
+ DSizes<ptrdiff_t, 2> out_dbg_dim;
+ out_dbg_dim[0] = 4;
+ out_dbg_dim[1] = 12;
+
+ // This is the formula for computing the backward prop for input with a
+ // spatial convolution.
+ Tensor<float, 4> direct =
+ kern.reverse(kernel_reverse)
+ .reshape(k_dims)
+ .contract(
+ out.extract_image_patches(2, 2, 1, 1, 1, 1, 2, 2, 1, 2, 1, 1, 0)
+ .reshape(out_dim),
+ contract_dims)
+ .reshape(in_dim);
+
+ Tensor<float, 4> indirect =
+ kern.reverse(kernel_reverse)
+ .reshape(k_dims)
+ .contract(
+ out.inflate(strides)
+ .pad(paddings)
+ .extract_image_patches(2, 2, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0)
+ .reshape(out_dim),
+ contract_dims)
+ .reshape(in_dim);
+
+ eigen_assert(dimensions_match(direct.dimensions(), indirect.dimensions()));
+ for (size_t i = 0; i < direct.dimensions().TotalSize(); ++i) {
+ EigenApprox(direct.data()[i], indirect.data()[i]);
+ }
+ EigenApprox(1.0f, direct(0, 0, 0, 0));
+ EigenApprox(3.0f, direct(0, 0, 1, 0));
+ EigenApprox(2.0f, direct(0, 0, 2, 0));
+ EigenApprox(6.0f, direct(0, 0, 3, 0));
+
+ EigenApprox(2.0f, direct(0, 1, 0, 0));
+ EigenApprox(4.0f, direct(0, 1, 1, 0));
+ EigenApprox(4.0f, direct(0, 1, 2, 0));
+ EigenApprox(8.0f, direct(0, 1, 3, 0));
+}
+
+} // namespace Eigen
diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc
index b6755c61a5..97cf15b5dd 100644
--- a/tensorflow/core/kernels/maxpooling_op.cc
+++ b/tensorflow/core/kernels/maxpooling_op.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/core/kernels/maxpooling_op.h"
#include <vector>
-#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/numeric_op.h"
@@ -29,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/eigen_pooling.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/pooling_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
diff --git a/tensorflow/core/kernels/maxpooling_op.h b/tensorflow/core/kernels/maxpooling_op.h
index f94ed882b7..ec34337efd 100644
--- a/tensorflow/core/kernels/maxpooling_op.h
+++ b/tensorflow/core/kernels/maxpooling_op.h
@@ -17,8 +17,8 @@ limitations under the License.
#define TENSORFLOW_KERNELS_MAXPOOLING_OP_H_
// Functor definition for MaxPoolingOp, must be compilable by nvcc.
-#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks"
#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_pooling.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.h b/tensorflow/core/kernels/maxpooling_op_gpu.h
index b46a339392..4d8d0e7fa7 100644
--- a/tensorflow/core/kernels/maxpooling_op_gpu.h
+++ b/tensorflow/core/kernels/maxpooling_op_gpu.h
@@ -22,7 +22,6 @@ limitations under the License.
#define EIGEN_USE_GPU
-#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/kernels/pooling_ops_common.h b/tensorflow/core/kernels/pooling_ops_common.h
index f9f16d96d8..21396464fb 100644
--- a/tensorflow/core/kernels/pooling_ops_common.h
+++ b/tensorflow/core/kernels/pooling_ops_common.h
@@ -18,7 +18,6 @@ limitations under the License.
#include <vector>
-#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
diff --git a/tensorflow/core/kernels/pooling_ops_common_gpu.h b/tensorflow/core/kernels/pooling_ops_common_gpu.h
index a1d4c4504d..0ef55a9677 100644
--- a/tensorflow/core/kernels/pooling_ops_common_gpu.h
+++ b/tensorflow/core/kernels/pooling_ops_common_gpu.h
@@ -21,7 +21,6 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_POOLING_OPS_COMMON_GPU_H_
#include <vector>
-#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
diff --git a/tensorflow/core/kernels/resize_nearest_neighbor_op_gpu.h b/tensorflow/core/kernels/resize_nearest_neighbor_op_gpu.h
index 65b4b331d9..056d5a7316 100644
--- a/tensorflow/core/kernels/resize_nearest_neighbor_op_gpu.h
+++ b/tensorflow/core/kernels/resize_nearest_neighbor_op_gpu.h
@@ -20,7 +20,6 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_RESIZE_NEAREST_NEIGHBOR_OP_GPU_H_
#define TENSORFLOW_CORE_KERNELS_RESIZE_NEAREST_NEIGHBOR_OP_GPU_H_
-#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"