aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2018-08-31 15:44:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-31 15:48:53 -0700
commite0d39b135a24e577947bd90c6be45f54cd11f4f8 (patch)
treef9bf6cd822a9be41f525b6492625e41b342c8235 /tensorflow/core
parentdac56d2637fb8361de3da96c51c83ce3ed1ad4da (diff)
Benchmarks for CuboidConvolutions.
PiperOrigin-RevId: 211156403
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/kernels/eigen_benchmark.h96
-rw-r--r--tensorflow/core/kernels/eigen_benchmark_cpu_test.cc148
2 files changed, 230 insertions, 14 deletions
diff --git a/tensorflow/core/kernels/eigen_benchmark.h b/tensorflow/core/kernels/eigen_benchmark.h
index e4875ee0e3..c18b033466 100644
--- a/tensorflow/core/kernels/eigen_benchmark.h
+++ b/tensorflow/core/kernels/eigen_benchmark.h
@@ -18,7 +18,9 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h"
#include "tensorflow/core/kernels/eigen_backward_spatial_convolutions.h"
+#include "tensorflow/core/kernels/eigen_cuboid_convolution.h"
#include "tensorflow/core/kernels/eigen_spatial_convolutions.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -115,4 +117,98 @@ class SpatialConvolutionBenchmarksSuite {
Device& device_;
};
+template <typename Scalar, typename Device>
+class CuboidConvolutionBenchmarksSuite {
+ public:
+ using Input = TTypes<float, 5>::ConstTensor;
+ using Filter = TTypes<float, 5>::ConstTensor;
+ using Output = TTypes<float, 5>::Tensor;
+
+ using Dimensions = Eigen::DSizes<Eigen::Index, 5>;
+
+ CuboidConvolutionBenchmarksSuite(int iters, Device& device)
+ : iters_(iters), device_(device) {}
+
+ Eigen::Index BufferSize(const Dimensions& dims) {
+ return dims.TotalSize() * sizeof(Scalar);
+ }
+
+ void CuboidConvolution(Dimensions input_dims, Dimensions filter_dims) {
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ input_dims[3], // input_planes
+ filter_dims[4]); // filter_count
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ Scalar* filter_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+ Scalar* output_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+
+ device_.memset(input_data, 123, BufferSize(input_dims));
+ device_.memset(filter_data, 123, BufferSize(filter_dims));
+
+ Input input(input_data, input_dims);
+ Filter filter(filter_data, filter_dims);
+ Output output(output_data, output_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ output.device(device_) = Eigen::CuboidConvolution(input, filter);
+ tensorflow::testing::DoNotOptimize(output);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(input_data);
+ device_.deallocate(filter_data);
+ device_.deallocate(output_data);
+ }
+
+ void CuboidConvolutionBackwardInput(Dimensions input_dims,
+ Dimensions filter_dims) {
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ input_dims[3], // input_planes
+ filter_dims[4]); // filter_count
+
+ // Assuming that the convolution had SAME padding.
+ Eigen::Index input_rows = input_dims[1];
+ Eigen::Index input_cols = input_dims[2];
+ Eigen::Index input_planes = input_dims[3];
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ Scalar* filter_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+ Scalar* output_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+
+ device_.memset(input_data, 123, BufferSize(input_dims));
+ device_.memset(filter_data, 123, BufferSize(filter_dims));
+
+ Input input(input_data, input_dims);
+ Filter filter(filter_data, filter_dims);
+ Output output(output_data, output_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ output.device(device_) = Eigen::CuboidConvolutionBackwardInput(
+ filter, input, input_planes, input_rows, input_cols);
+ tensorflow::testing::DoNotOptimize(output);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(input_data);
+ device_.deallocate(filter_data);
+ device_.deallocate(output_data);
+ }
+
+ private:
+ int iters_;
+ Device& device_;
+};
+
#endif // TENSORFLOW_CORE_KERNELS_EIGEN_BENCHMARK_H_
diff --git a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
index ddfb21dcb5..fde406ba31 100644
--- a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
+++ b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
@@ -23,6 +23,10 @@ limitations under the License.
Eigen::ThreadPool tp(threads); \
Eigen::ThreadPoolDevice device(&tp, threads)
+// -------------------------------------------------------------------------- //
+// Spatial Convolutions //
+// -------------------------------------------------------------------------- //
+
void SpatialConvolution(int iters, int num_threads,
/* Input dimensions: */
int input_batches, int input_height, int input_width,
@@ -86,22 +90,23 @@ void SpatialConvolutionBackwardInput(int iters, int num_threads,
// FH: filter height
// FW: filter width
-#define BM_NAME(prefix, NT, N, H, W, C, FC, FH, FW) \
+#define BM_SPATIAL_NAME(prefix, NT, N, H, W, C, FC, FH, FW) \
BM_##prefix##_CPU_##NT##T_in_##N##_##H##_##W##_##C##_f_##FC##_##FH##_##FW
-#define BM_SpatialConvolution(NT, N, H, W, C, FC, FH, FW, LABEL) \
- static void BM_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, \
- FW)(int iters) { \
- SpatialConvolution(iters, NT, N, H, W, C, FC, FH, FW); \
- } \
- BENCHMARK(BM_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, FW))
-
-#define BM_SpatialConvolutionBwdInput(NT, N, H, W, C, FC, FH, FW, LABEL) \
- static void BM_NAME(SpatialConvolutionBwdInput, NT, N, H, W, C, FC, FH, \
- FW)(int iters) { \
- SpatialConvolutionBackwardInput(iters, NT, N, H, W, C, FC, FH, FW); \
+#define BM_SpatialConvolution(NT, N, H, W, C, FC, FH, FW, LABEL) \
+ static void BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, \
+ FW)(int iters) { \
+ SpatialConvolution(iters, NT, N, H, W, C, FC, FH, FW); \
} \
- BENCHMARK(BM_NAME(SpatialConvolutionBwdInput, NT, N, H, W, C, FC, FH, FW))
+ BENCHMARK(BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, FW))
+
+#define BM_SpatialConvolutionBwdInput(NT, N, H, W, C, FC, FH, FW, LABEL) \
+ static void BM_SPATIAL_NAME(SpatialConvolutionBwdInput, NT, N, H, W, C, FC, \
+ FH, FW)(int iters) { \
+ SpatialConvolutionBackwardInput(iters, NT, N, H, W, C, FC, FH, FW); \
+ } \
+ BENCHMARK( \
+ BM_SPATIAL_NAME(SpatialConvolutionBwdInput, NT, N, H, W, C, FC, FH, FW))
#define BM_SpatialConvolutions(N, H, W, C, FC, FH, FW, LABEL) \
BM_SpatialConvolution(2, N, H, W, C, FC, FH, FW, LABEL); \
@@ -119,7 +124,7 @@ void SpatialConvolutionBackwardInput(int iters, int num_threads,
BM_SpatialConvolutions(32, // batch size
56, 56, 64, // input: height, width, depth
- 192, 3, 3, // filter: height, width, count
+ 192, 3, 3, // filter: count, height, width
"conv2_00");
BM_SpatialConvolutions(32, 28, 28, 96, 128, 3, 3, "conv3a_00_3x3");
@@ -168,3 +173,118 @@ BM_SpatialConvolutionsBwdInput(32, 7, 7, 160, 320, 3, 3, "conv5a_00_3x3");
BM_SpatialConvolutionsBwdInput(32, 7, 7, 48, 128, 5, 5,
"conv5a_00_5x5 / conv5_00_5x5");
BM_SpatialConvolutionsBwdInput(32, 7, 7, 192, 384, 3, 3, "conv5_00_3x3");
+
+// -------------------------------------------------------------------------- //
+// Cuboid Convolutions //
+// -------------------------------------------------------------------------- //
+
+void CuboidConvolution(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height, int input_width,
+ int input_planes, int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height, int filter_width,
+ int filter_planes) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ CuboidConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(
+ input_batches, input_height, input_width, input_planes, input_depth);
+ typename Benchmark::Dimensions filter_dims(
+ filter_height, filter_width, filter_planes, input_depth, filter_count);
+
+ benchmark.CuboidConvolution(input_dims, filter_dims);
+
+ auto output_size = input_dims.TotalSize();
+ auto flops = output_size *
+ (input_depth * filter_height * filter_width * filter_planes);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+void CuboidConvolutionBackwardInput(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height,
+ int input_width, int input_planes,
+ int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height,
+ int filter_width, int filter_planes) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ CuboidConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(
+ input_batches, input_height, input_width, input_planes, input_depth);
+ typename Benchmark::Dimensions filter_dims(
+ filter_height, filter_width, filter_planes, input_depth, filter_count);
+
+ benchmark.CuboidConvolutionBackwardInput(input_dims, filter_dims);
+
+ auto output_size = input_dims.TotalSize();
+ auto flops = output_size *
+ (input_depth * filter_height * filter_width * filter_planes);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+// Macro arguments names: --------------------------------------------------- //
+// NT: num threads
+// N: batch size
+// H: height
+// W: width
+// P: panes
+// C: channels
+// FC: filter count
+// FH: filter height
+// FW: filter width
+// FP: filter panes
+
+#define BM_CUBOID_NAME(p, NT, N, H, W, P, C, FC, FH, FW, FP) \
+ BM_##p##_CPU_##NT##T_in_##N##_##H##_##W##_##P##_##_##C##_f_##FC##_##FH##_##FW
+
+#define BM_CuboidConvolution(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ static void BM_CUBOID_NAME(CuboidConvolution, NT, N, H, W, P, C, FC, FH, FW, \
+ FP)(int iters) { \
+ CuboidConvolution(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
+ } \
+ BENCHMARK( \
+ BM_CUBOID_NAME(CuboidConvolution, NT, N, H, W, P, C, FC, FH, FW, FP))
+
+#define BM_CuboidConvolutionBwdInput(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ static void BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \
+ FH, FW, FP)(int iters) { \
+ CuboidConvolutionBackwardInput(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
+ } \
+ BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \
+ FH, FW, FP))
+
+#define BM_CuboidConvolutions(N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ BM_CuboidConvolution(2, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolution(4, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolution(8, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolution(16, N, H, W, P, C, FC, FH, FW, FP, LABEL);
+
+#define BM_CuboidConvolutionsBwdInput(N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ BM_CuboidConvolutionBwdInput(2, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdInput(4, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdInput(8, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdInput(16, N, H, W, P, C, FC, FH, FW, FP, LABEL);
+
+// Random Cuboid Convolutions ----------------------------------------------- //
+// TODO(ezhulenev): find representative dims for cuboid convolutions (find
+// models using Conv3D ops).
+
+BM_CuboidConvolutions(16, // batch size
+ 25, 25, 25, 8, // input: height, width, panes, depth
+ 32, 5, 5, 5, // filter: count, height, width, panes
+ "conv3d");
+
+BM_CuboidConvolutionsBwdInput(16, 25, 25, 25, 8, 32, 5, 5, 5, "conv3d");