aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/matrix_band_part_op.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-15 10:05:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-15 10:08:44 -0700
commit2835ebaa9fb4e38c01c165da1b6dd80a6250fd3f (patch)
treeb2248666a2526eca5241d58f5b6d827ffb7b290a /tensorflow/core/kernels/matrix_band_part_op.cc
parent6d51dd66d1fd938fa0f95f5933169aaccd6aef76 (diff)
Optimize C++ kernels for the matrix_band_part op, which is used in various ops operating on triangular or banded matrices:
* Add benchmark for matrix_band_part. * Implement simple optimized CUDA kernel instead of calling Eigen generator. * Parallelize CPU kernel for matrix_band_part. * Support on-the-fly transposition in the underlying functors (to be used for future QR op in followup). Benchmarks: First column is of the form {device}_{shape}_{num_lower,num_upper} Test case Before After Speedup cpu_(10,16,16)_(-1,-1) 5.6505e-05 6.2108e-05 -9.92% cpu_(10,16,16)_(-1,0) 0.00010848 0.00010908 -0.55% cpu_(10,16,16)_(0,-1) 0.0001055 0.00011396 -8.02% cpu_(10,16,16)_(2,2) 0.000108 0.00011706 -8.39% cpu_(10,101,101)_(-1,-1) 0.00013697 6.0558e-05 +55.79% cpu_(10,101,101)_(-1,0) 0.00054002 0.00017703 +67.22% cpu_(10,101,101)_(0,-1) 0.00051188 0.00017607 +65.60% cpu_(10,101,101)_(2,2) 0.00050449 0.00016904 +66.49% cpu_(10,256,256)_(-1,-1) 0.00032043 5.6028e-05 +82.51% cpu_(10,256,256)_(-1,0) 0.001335 0.0004015 +69.93% cpu_(10,256,256)_(0,-1) 0.0013521 0.00038862 +71.26% cpu_(10,256,256)_(2,2) 0.001269 0.00039959 +68.51% cpu_(10,1000,1000)_(-1,-1) 0.0090729 6.3419e-05 +99.30% cpu_(10,1000,1000)_(-1,0) 0.01712 0.0047594 +72.20% cpu_(10,1000,1000)_(0,-1) 0.016647 0.0046474 +72.08% cpu_(10,1000,1000)_(2,2) 0.012737 0.0041161 +67.68% cpu_(10,1024,1024)_(-1,-1) 0.0093709 5.8889e-05 +99.37% cpu_(10,1024,1024)_(-1,0) 0.017075 0.0051999 +69.55% cpu_(10,1024,1024)_(0,-1) 0.016867 0.004617 +72.63% cpu_(10,1024,1024)_(2,2) 0.013191 0.003759 +71.50% cpu_(10,2048,2048)_(-1,-1) 0.028427 6.2466e-05 +99.78% cpu_(10,2048,2048)_(-1,0) 0.048134 0.017642 +63.35% cpu_(10,2048,2048)_(0,-1) 0.048773 0.017558 +64.00% cpu_(10,2048,2048)_(2,2) 0.036153 0.015452 +57.26% cpu_(10,10,4,4)_(-1,-1) 5.8055e-05 5.8055e-05 +0.00% cpu_(10,10,4,4)_(-1,0) 0.00015557 0.0001564 -0.54% cpu_(10,10,4,4)_(0,-1) 0.00015855 0.00015199 +4.14% cpu_(10,10,4,4)_(2,2) 0.00016379 0.00018096 -10.48% cpu_(10,10,10,10)_(-1,-1) 6.0558e-05 6.0558e-05 +0.00% cpu_(10,10,10,10)_(-1,0) 0.000368 0.00038695 -5.15% cpu_(10,10,10,10)_(0,-1) 0.00036263 0.00038612 -6.48% cpu_(10,10,10,10)_(2,2) 0.00038648 0.00042963 -11.17% cpu_(10,10,16,16)_(-1,-1) 6.9022e-05 5.7578e-05 +16.58% cpu_(10,10,16,16)_(-1,0) 0.0005815 0.0001874 +67.77% cpu_(10,10,16,16)_(0,-1) 0.00059354 0.0001924 +67.58% cpu_(10,10,16,16)_(2,2) 0.00062239 0.00019097 +69.32% cpu_(10,10,101,101)_(-1,-1) 0.00014806 6.2823e-05 +57.57% cpu_(10,10,101,101)_(-1,0) 0.0039785 0.00078249 +80.33% cpu_(10,10,101,101)_(0,-1) 0.0040585 0.00076556 +81.14% cpu_(10,10,101,101)_(2,2) 0.0039514 0.00077307 +80.44% cpu_(10,10,256,256)_(-1,-1) 0.0026824 6.0558e-05 +97.74% cpu_(10,10,256,256)_(-1,0) 0.017269 0.0031619 +81.69% cpu_(10,10,256,256)_(0,-1) 0.020287 0.0030774 +84.83% cpu_(10,10,256,256)_(2,2) 0.011919 0.0026599 +77.68% cpu_(10,10,1000,1000)_(-1,-1) 0.065783 5.6982e-05 +99.91% cpu_(10,10,1000,1000)_(-1,0) 0.1361 0.054533 +59.93% cpu_(10,10,1000,1000)_(0,-1) 0.1397 0.053405 +61.77% cpu_(10,10,1000,1000)_(2,2) 0.10173 0.048561 +52.26% cpu_(10,10,1024,1024)_(-1,-1) 0.066231 7.5579e-05 +99.89% cpu_(10,10,1024,1024)_(-1,0) 0.13615 0.059931 +55.98% cpu_(10,10,1024,1024)_(0,-1) 0.13745 0.064931 +52.76% cpu_(10,10,1024,1024)_(2,2) 0.10493 0.054258 +48.29% cpu_(10,10,2048,2048)_(-1,-1) 0.23487 6.6042e-05 +99.97% cpu_(10,10,2048,2048)_(-1,0) 0.41014 0.24283 +40.79% cpu_(10,10,2048,2048)_(0,-1) 0.43621 0.26393 +39.49% cpu_(10,10,2048,2048)_(2,2) 0.29919 0.22302 +25.46% gpu_(10,16,16)_(-1,-1) 0.00010753 0.00010753 +0.00% gpu_(10,16,16)_(-1,0) 0.00011253 0.00012445 -10.59% gpu_(10,16,16)_(0,-1) 0.00012493 0.00013399 -7.25% gpu_(10,16,16)_(2,2) 0.000108 0.00011754 -8.83% gpu_(10,101,101)_(-1,-1) 0.00011849 8.7976e-05 +25.75% gpu_(10,101,101)_(-1,0) 0.00012743 0.00012243 +3.93% gpu_(10,101,101)_(0,-1) 0.00012958 0.00012362 +4.60% gpu_(10,101,101)_(2,2) 0.00011504 0.00011504 +0.00% gpu_(10,256,256)_(-1,-1) 0.00013447 9.7513e-05 +27.48% gpu_(10,256,256)_(-1,0) 0.00018752 0.00014746 +21.36% gpu_(10,256,256)_(0,-1) 0.00017798 0.00016904 +5.02% gpu_(10,256,256)_(2,2) 0.0001514 0.00013697 +9.53% gpu_(10,1000,1000)_(-1,-1) 0.0005095 9.8586e-05 +80.65% gpu_(10,1000,1000)_(-1,0) 0.00088501 0.00056589 +36.06% gpu_(10,1000,1000)_(0,-1) 0.00090456 0.00055242 +38.93% gpu_(10,1000,1000)_(2,2) 0.00080955 0.00049639 +38.68% gpu_(10,1024,1024)_(-1,-1) 0.00050902 9.7036e-05 +80.94% gpu_(10,1024,1024)_(-1,0) 0.00098789 0.00058246 +41.04% gpu_(10,1024,1024)_(0,-1) 0.001 0.00059545 +40.46% gpu_(10,1024,1024)_(2,2) 0.00082254 0.00049961 +39.26% gpu_(10,2048,2048)_(-1,-1) 0.001495 9.8944e-05 +93.38% gpu_(10,2048,2048)_(-1,0) 0.003535 0.0017736 +49.83% gpu_(10,2048,2048)_(0,-1) 0.0034965 0.0017921 +48.75% gpu_(10,2048,2048)_(2,2) 0.0027704 0.0015399 +44.41% gpu_(10,10,4,4)_(-1,-1) 0.00011086 9.1076e-05 +17.85% gpu_(10,10,4,4)_(-1,0) 0.0001235 0.00013411 -8.59% gpu_(10,10,4,4)_(0,-1) 0.00011849 0.0001204 -1.61% gpu_(10,10,4,4)_(2,2) 0.00010896 0.00013256 -21.66% gpu_(10,10,10,10)_(-1,-1) 0.00010657 9.5844e-05 +10.07% gpu_(10,10,10,10)_(-1,0) 0.00011754 0.00013602 -15.72% gpu_(10,10,10,10)_(0,-1) 0.00011909 0.00012004 -0.80% gpu_(10,10,10,10)_(2,2) 0.00013196 0.00011349 +14.00% gpu_(10,10,16,16)_(-1,-1) 0.00012898 0.00010705 +17.01% gpu_(10,10,16,16)_(-1,0) 0.00014353 0.00012338 +14.04% gpu_(10,10,16,16)_(0,-1) 0.00011599 0.00012493 -7.71% gpu_(10,10,16,16)_(2,2) 0.00011539 0.00011349 +1.65% gpu_(10,10,101,101)_(-1,-1) 0.00014699 0.00010252 +30.25% gpu_(10,10,101,101)_(-1,0) 0.0002141 0.00015497 +27.62% gpu_(10,10,101,101)_(0,-1) 0.0002017 0.00015843 +21.45% gpu_(10,10,101,101)_(2,2) 0.00018394 0.00015402 +16.27% gpu_(10,10,256,256)_(-1,-1) 0.00032747 9.0003e-05 +72.52% gpu_(10,10,256,256)_(-1,0) 0.00074494 0.00040746 +45.30% gpu_(10,10,256,256)_(0,-1) 0.00072503 0.00042391 +41.53% gpu_(10,10,256,256)_(2,2) 0.00061846 0.00038004 +38.55% gpu_(10,10,1000,1000)_(-1,-1) 0.0032645 0.00010896 +96.66% gpu_(10,10,1000,1000)_(-1,0) 0.007543 0.0038971 +48.34% gpu_(10,10,1000,1000)_(0,-1) 0.006058 0.0039405 +34.95% gpu_(10,10,1000,1000)_(2,2) 0.005198 0.003448 +33.67% gpu_(10,10,1024,1024)_(-1,-1) 0.0034155 9.1434e-05 +97.32% gpu_(10,10,1024,1024)_(-1,0) 0.007099 0.004158 +41.43% gpu_(10,10,1024,1024)_(0,-1) 0.006843 0.003849 +43.75% gpu_(10,10,1024,1024)_(2,2) 0.005506 0.0031376 +43.02% gpu_(10,10,2048,2048)_(-1,-1) 0.013119 0.00010097 +99.23% gpu_(10,10,2048,2048)_(-1,0) 0.028533 0.015175 +46.81% gpu_(10,10,2048,2048)_(0,-1) 0.028458 0.014926 +47.55% gpu_(10,10,2048,2048)_(2,2) 0.022175 0.011797 +46.80% PiperOrigin-RevId: 168849471
Diffstat (limited to 'tensorflow/core/kernels/matrix_band_part_op.cc')
-rw-r--r--tensorflow/core/kernels/matrix_band_part_op.cc167
1 files changed, 126 insertions, 41 deletions
diff --git a/tensorflow/core/kernels/matrix_band_part_op.cc b/tensorflow/core/kernels/matrix_band_part_op.cc
index 894b0113c2..8b8accc0b3 100644
--- a/tensorflow/core/kernels/matrix_band_part_op.cc
+++ b/tensorflow/core/kernels/matrix_band_part_op.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/kernels/matrix_band_part_op.h"
+#include <algorithm>
#include <memory>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -32,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -48,31 +50,50 @@ class MatrixBandPartOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
+ const TensorShape& input_shape = input.shape();
+ // Preliminary validation of sizes.
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
+ errors::InvalidArgument(
+ "input must be at least 2-dim, received shape: ",
+ input.shape().DebugString()));
+ auto input_reshaped = input.flat_inner_dims<T, 3>();
+
const Tensor& num_lower_in = context->input(1);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in.shape()),
errors::InvalidArgument("num_lower must be scalar, got shape ",
num_lower_in.shape().DebugString()));
const int64 num_lower = num_lower_in.scalar<int64>()();
+ OP_REQUIRES(
+ context, num_lower <= input_reshaped.dimension(1),
+ errors::InvalidArgument(
+ "num_lower must be negative or less or equal to number of rows (",
+ input_reshaped.dimension(1), ") got: ", num_lower));
const Tensor& num_upper_in = context->input(2);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in.shape()),
errors::InvalidArgument("num_upper must be scalar, got shape ",
num_upper_in.shape().DebugString()));
const int64 num_upper = num_upper_in.scalar<int64>()();
+ OP_REQUIRES(context, num_upper <= input_reshaped.dimension(2),
+ errors::InvalidArgument("num_upper must be negative or less or "
+ "equal to number of columns (",
+ input_reshaped.dimension(2),
+ ") got: ", num_upper));
+
+ if ((num_lower < 0 || num_lower == input_reshaped.dimension(1)) &&
+ (num_upper < 0 || num_upper == input_reshaped.dimension(2))) {
+ // This is a no-op.
+ context->set_output(0, input);
+ return;
+ }
- const TensorShape& input_shape = input.shape();
- // Preliminary validation of sizes.
- OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
- errors::InvalidArgument(
- "input must be at least 2-dim, received shape: ",
- input.shape().DebugString()));
- auto input_reshaped = input.flat_inner_dims<T, 3>();
Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &output));
+ OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
+ {0}, 0, input_shape, &output));
auto output_reshaped = output->flat_inner_dims<T, 3>();
- functor::MatrixBandPart<Device, T>::Compute(
- context->eigen_device<Device>(), num_lower, num_upper, input_reshaped,
- output_reshaped);
+ functor::MatrixBandPartFunctor<Device, T> fn;
+ fn(context, context->eigen_device<Device>(), num_lower, num_upper,
+ false /* transpose */, input_reshaped, output_reshaped);
}
private:
@@ -98,54 +119,118 @@ TF_CALL_NUMBER_TYPES(REGISTER_BATCH_MATRIX_BAND_PART);
// Implementation of the functor specialization for CPU.
namespace functor {
-template <typename T>
-struct MatrixBandPart<CPUDevice, T> {
- static void Compute(const CPUDevice& d, int64 num_lower, int64 num_upper,
- typename TTypes<T, 3>::ConstTensor input,
- typename TTypes<T, 3>::Tensor output) {
- if ((num_lower < 0 || num_lower >= input.dimension(1)) &&
- (num_upper < 0 || num_upper >= input.dimension(2))) {
- output.device(d) = input;
+
+// CPU implementation of BandPartFunctor.
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename Scalar>
+struct MatrixBandPartFunctor<CPUDevice, Scalar> {
+ void operator()(OpKernelContext* context, const CPUDevice& device,
+ int num_lower_diags, int num_upper_diags, bool transpose,
+ typename TTypes<Scalar, 3>::ConstTensor input,
+ typename TTypes<Scalar, 3>::Tensor output) {
+ const int64 b = input.dimension(0);
+ const int64 m = input.dimension(1);
+ const int64 n = input.dimension(2);
+ auto thread_pool =
+ context->device()->tensorflow_cpu_worker_threads()->workers;
+ const int64 total_rows = b * m;
+ const int64 row_cost = 10 * n;
+ const bool in_place = input.data() == output.data();
+ CHECK(!(transpose && in_place));
+ if (!transpose) {
+ auto compute_shard = [=, &input, &output](int64 begin, int64 end) {
+ if (!in_place) {
+ std::fill(output.data() + begin * n, output.data() + end * n,
+ Scalar());
+ }
+ const int64 batch_begin = begin / m;
+ const int64 batch_end = (end + m - 1) / m;
+ for (int64 batch = batch_begin; batch < batch_end; ++batch) {
+ const int64 row_begin = begin > batch * m ? begin % m : 0;
+ const int64 row_end = end < (batch + 1) * m ? end % m : m;
+ for (int64 row = row_begin; row < row_end; ++row) {
+ const int64 band_start =
+ num_lower_diags < 0
+ ? 0
+ : std::min(n, std::max(0ll, row - num_lower_diags));
+ const int64 band_end = num_upper_diags < 0
+ ? n
+ : std::min(static_cast<int64>(n),
+ row + num_upper_diags + 1);
+ if (in_place) {
+ if (band_start > 0) {
+ std::fill(&output(batch, row, 0),
+ &output(batch, row, band_start), Scalar());
+ }
+ if (band_end < n) {
+ std::fill(&output(batch, row, band_end), &output(batch, row, n),
+ Scalar());
+ }
+ } else {
+ if (band_start < band_end) {
+ const Eigen::DSizes<Eigen::DenseIndex, 3> indices(batch, row,
+ band_start);
+ const Eigen::DSizes<Eigen::DenseIndex, 3> sizes(
+ 1, 1, band_end - band_start);
+ output.slice(indices, sizes) = input.slice(indices, sizes);
+ }
+ }
+ }
+ }
+ };
+ thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard));
} else {
- output.device(d) = output.constant(T());
- for (int64 r = 0; r < output.dimension(0); ++r) {
- for (int64 i = 0; i < output.dimension(1); ++i) {
- const int64 band_start =
- num_lower < 0 ? 0 : std::max(0ll, i - num_lower);
- const int64 band_end =
- num_upper < 0 ? output.dimension(2)
- : std::min(static_cast<int64>(output.dimension(2)),
- i + num_upper + 1);
- if (band_start < band_end) {
- const Eigen::DSizes<Eigen::DenseIndex, 3> indices(r, i, band_start);
- const Eigen::DSizes<Eigen::DenseIndex, 3> sizes(
- 1, 1, band_end - band_start);
- output.slice(indices, sizes) = input.slice(indices, sizes);
+ output.device(device) = output.constant(Scalar());
+ auto compute_shard = [=, &input, &output](int64 begin, int64 end) {
+ const int64 batch_begin = begin / m;
+ const int64 batch_end = (end + m - 1) / m;
+ for (int64 batch = batch_begin; batch < batch_end; ++batch) {
+ const int64 row_begin = begin > batch * m ? begin % m : 0;
+ const int64 row_end = end < (batch + 1) * m ? end % m : m;
+ for (int64 row = row_begin; row < row_end; ++row) {
+ const int64 band_start =
+ num_lower_diags < 0 ? 0 : std::max(0ll, row - num_lower_diags);
+ const int64 band_end = num_upper_diags < 0
+ ? n
+ : std::min(static_cast<int64>(n),
+ row + num_upper_diags + 1);
+ for (int64 col = band_start; col < band_end; ++col) {
+ output(batch, col, row) = input(batch, row, col);
+ }
}
}
- }
+ };
+ thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard));
}
}
};
+#define DEFINE_CPU_SPEC(T) template struct MatrixBandPartFunctor<CPUDevice, T>;
+TF_CALL_POD_TYPES(DEFINE_CPU_SPEC);
+#undef DEFINE_CPU_SPEC
+
} // namespace functor
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
-#define DECLARE_GPU_SPEC(T) \
- template <> \
- void MatrixBandPart<GPUDevice, T>::Compute( \
- const GPUDevice& d, Eigen::DenseIndex num_lower, \
- Eigen::DenseIndex num_upper, typename TTypes<T, 3>::ConstTensor input, \
- typename TTypes<T, 3>::Tensor output); \
- extern template struct MatrixBandPart<GPUDevice, T>;
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ struct MatrixBandPartFunctor<GPUDevice, T> { \
+ void operator()(OpKernelContext* context, const GPUDevice& device, \
+ int num_upper_diags, int num_lower_diags, bool transpose, \
+ typename TTypes<T, 3>::ConstTensor input, \
+ typename TTypes<T, 3>::Tensor output); \
+ }; \
+ extern template struct MatrixBandPartFunctor<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_bool(DECLARE_GPU_SPEC);
TF_CALL_complex64(DECLARE_GPU_SPEC);
TF_CALL_complex128(DECLARE_GPU_SPEC);
+#undef DECLARE_GPU_SPEC
} // namespace functor
// Registration of the GPU implementations.