author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-09-15 10:05:38 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-09-15 10:08:44 -0700
commit     2835ebaa9fb4e38c01c165da1b6dd80a6250fd3f
tree       b2248666a2526eca5241d58f5b6d827ffb7b290a /tensorflow
parent     6d51dd66d1fd938fa0f95f5933169aaccd6aef76
Optimize C++ kernels for the matrix_band_part op, which is used in various
ops operating on triangular or banded matrices:

* Add a benchmark for matrix_band_part.
* Implement a simple optimized CUDA kernel instead of calling the Eigen
  generator.
* Parallelize the CPU kernel for matrix_band_part.
* Support on-the-fly transposition in the underlying functors (to be used
  for the future QR op in a follow-up).

Benchmarks (first column is of the form {device}_{shape}_{num_lower,num_upper}):

Test case                       Before       After        Speedup
cpu_(10,16,16)_(-1,-1)          5.6505e-05   6.2108e-05    -9.92%
cpu_(10,16,16)_(-1,0)           0.00010848   0.00010908    -0.55%
cpu_(10,16,16)_(0,-1)           0.0001055    0.00011396    -8.02%
cpu_(10,16,16)_(2,2)            0.000108     0.00011706    -8.39%
cpu_(10,101,101)_(-1,-1)        0.00013697   6.0558e-05   +55.79%
cpu_(10,101,101)_(-1,0)         0.00054002   0.00017703   +67.22%
cpu_(10,101,101)_(0,-1)         0.00051188   0.00017607   +65.60%
cpu_(10,101,101)_(2,2)          0.00050449   0.00016904   +66.49%
cpu_(10,256,256)_(-1,-1)        0.00032043   5.6028e-05   +82.51%
cpu_(10,256,256)_(-1,0)         0.001335     0.0004015    +69.93%
cpu_(10,256,256)_(0,-1)         0.0013521    0.00038862   +71.26%
cpu_(10,256,256)_(2,2)          0.001269     0.00039959   +68.51%
cpu_(10,1000,1000)_(-1,-1)      0.0090729    6.3419e-05   +99.30%
cpu_(10,1000,1000)_(-1,0)       0.01712      0.0047594    +72.20%
cpu_(10,1000,1000)_(0,-1)       0.016647     0.0046474    +72.08%
cpu_(10,1000,1000)_(2,2)        0.012737     0.0041161    +67.68%
cpu_(10,1024,1024)_(-1,-1)      0.0093709    5.8889e-05   +99.37%
cpu_(10,1024,1024)_(-1,0)       0.017075     0.0051999    +69.55%
cpu_(10,1024,1024)_(0,-1)       0.016867     0.004617     +72.63%
cpu_(10,1024,1024)_(2,2)        0.013191     0.003759     +71.50%
cpu_(10,2048,2048)_(-1,-1)      0.028427     6.2466e-05   +99.78%
cpu_(10,2048,2048)_(-1,0)       0.048134     0.017642     +63.35%
cpu_(10,2048,2048)_(0,-1)       0.048773     0.017558     +64.00%
cpu_(10,2048,2048)_(2,2)        0.036153     0.015452     +57.26%
cpu_(10,10,4,4)_(-1,-1)         5.8055e-05   5.8055e-05    +0.00%
cpu_(10,10,4,4)_(-1,0)          0.00015557   0.0001564     -0.54%
cpu_(10,10,4,4)_(0,-1)          0.00015855   0.00015199    +4.14%
cpu_(10,10,4,4)_(2,2)           0.00016379   0.00018096   -10.48%
cpu_(10,10,10,10)_(-1,-1)       6.0558e-05   6.0558e-05    +0.00%
cpu_(10,10,10,10)_(-1,0)        0.000368     0.00038695    -5.15%
cpu_(10,10,10,10)_(0,-1)        0.00036263   0.00038612    -6.48%
cpu_(10,10,10,10)_(2,2)         0.00038648   0.00042963   -11.17%
cpu_(10,10,16,16)_(-1,-1)       6.9022e-05   5.7578e-05   +16.58%
cpu_(10,10,16,16)_(-1,0)        0.0005815    0.0001874    +67.77%
cpu_(10,10,16,16)_(0,-1)        0.00059354   0.0001924    +67.58%
cpu_(10,10,16,16)_(2,2)         0.00062239   0.00019097   +69.32%
cpu_(10,10,101,101)_(-1,-1)     0.00014806   6.2823e-05   +57.57%
cpu_(10,10,101,101)_(-1,0)      0.0039785    0.00078249   +80.33%
cpu_(10,10,101,101)_(0,-1)      0.0040585    0.00076556   +81.14%
cpu_(10,10,101,101)_(2,2)       0.0039514    0.00077307   +80.44%
cpu_(10,10,256,256)_(-1,-1)     0.0026824    6.0558e-05   +97.74%
cpu_(10,10,256,256)_(-1,0)      0.017269     0.0031619    +81.69%
cpu_(10,10,256,256)_(0,-1)      0.020287     0.0030774    +84.83%
cpu_(10,10,256,256)_(2,2)       0.011919     0.0026599    +77.68%
cpu_(10,10,1000,1000)_(-1,-1)   0.065783     5.6982e-05   +99.91%
cpu_(10,10,1000,1000)_(-1,0)    0.1361       0.054533     +59.93%
cpu_(10,10,1000,1000)_(0,-1)    0.1397       0.053405     +61.77%
cpu_(10,10,1000,1000)_(2,2)     0.10173      0.048561     +52.26%
cpu_(10,10,1024,1024)_(-1,-1)   0.066231     7.5579e-05   +99.89%
cpu_(10,10,1024,1024)_(-1,0)    0.13615      0.059931     +55.98%
cpu_(10,10,1024,1024)_(0,-1)    0.13745      0.064931     +52.76%
cpu_(10,10,1024,1024)_(2,2)     0.10493      0.054258     +48.29%
cpu_(10,10,2048,2048)_(-1,-1)   0.23487      6.6042e-05   +99.97%
cpu_(10,10,2048,2048)_(-1,0)    0.41014      0.24283      +40.79%
cpu_(10,10,2048,2048)_(0,-1)    0.43621      0.26393      +39.49%
cpu_(10,10,2048,2048)_(2,2)     0.29919      0.22302      +25.46%
gpu_(10,16,16)_(-1,-1)          0.00010753   0.00010753    +0.00%
gpu_(10,16,16)_(-1,0)           0.00011253   0.00012445   -10.59%
gpu_(10,16,16)_(0,-1)           0.00012493   0.00013399    -7.25%
gpu_(10,16,16)_(2,2)            0.000108     0.00011754    -8.83%
gpu_(10,101,101)_(-1,-1)        0.00011849   8.7976e-05   +25.75%
gpu_(10,101,101)_(-1,0)         0.00012743   0.00012243    +3.93%
gpu_(10,101,101)_(0,-1)         0.00012958   0.00012362    +4.60%
gpu_(10,101,101)_(2,2)          0.00011504   0.00011504    +0.00%
gpu_(10,256,256)_(-1,-1)        0.00013447   9.7513e-05   +27.48%
gpu_(10,256,256)_(-1,0)         0.00018752   0.00014746   +21.36%
gpu_(10,256,256)_(0,-1)         0.00017798   0.00016904    +5.02%
gpu_(10,256,256)_(2,2)          0.0001514    0.00013697    +9.53%
gpu_(10,1000,1000)_(-1,-1)      0.0005095    9.8586e-05   +80.65%
gpu_(10,1000,1000)_(-1,0)       0.00088501   0.00056589   +36.06%
gpu_(10,1000,1000)_(0,-1)       0.00090456   0.00055242   +38.93%
gpu_(10,1000,1000)_(2,2)        0.00080955   0.00049639   +38.68%
gpu_(10,1024,1024)_(-1,-1)      0.00050902   9.7036e-05   +80.94%
gpu_(10,1024,1024)_(-1,0)       0.00098789   0.00058246   +41.04%
gpu_(10,1024,1024)_(0,-1)       0.001        0.00059545   +40.46%
gpu_(10,1024,1024)_(2,2)        0.00082254   0.00049961   +39.26%
gpu_(10,2048,2048)_(-1,-1)      0.001495     9.8944e-05   +93.38%
gpu_(10,2048,2048)_(-1,0)       0.003535     0.0017736    +49.83%
gpu_(10,2048,2048)_(0,-1)       0.0034965    0.0017921    +48.75%
gpu_(10,2048,2048)_(2,2)        0.0027704    0.0015399    +44.41%
gpu_(10,10,4,4)_(-1,-1)         0.00011086   9.1076e-05   +17.85%
gpu_(10,10,4,4)_(-1,0)          0.0001235    0.00013411    -8.59%
gpu_(10,10,4,4)_(0,-1)          0.00011849   0.0001204     -1.61%
gpu_(10,10,4,4)_(2,2)           0.00010896   0.00013256   -21.66%
gpu_(10,10,10,10)_(-1,-1)       0.00010657   9.5844e-05   +10.07%
gpu_(10,10,10,10)_(-1,0)        0.00011754   0.00013602   -15.72%
gpu_(10,10,10,10)_(0,-1)        0.00011909   0.00012004    -0.80%
gpu_(10,10,10,10)_(2,2)         0.00013196   0.00011349   +14.00%
gpu_(10,10,16,16)_(-1,-1)       0.00012898   0.00010705   +17.01%
gpu_(10,10,16,16)_(-1,0)        0.00014353   0.00012338   +14.04%
gpu_(10,10,16,16)_(0,-1)        0.00011599   0.00012493    -7.71%
gpu_(10,10,16,16)_(2,2)         0.00011539   0.00011349    +1.65%
gpu_(10,10,101,101)_(-1,-1)     0.00014699   0.00010252   +30.25%
gpu_(10,10,101,101)_(-1,0)      0.0002141    0.00015497   +27.62%
gpu_(10,10,101,101)_(0,-1)      0.0002017    0.00015843   +21.45%
gpu_(10,10,101,101)_(2,2)       0.00018394   0.00015402   +16.27%
gpu_(10,10,256,256)_(-1,-1)     0.00032747   9.0003e-05   +72.52%
gpu_(10,10,256,256)_(-1,0)      0.00074494   0.00040746   +45.30%
gpu_(10,10,256,256)_(0,-1)      0.00072503   0.00042391   +41.53%
gpu_(10,10,256,256)_(2,2)       0.00061846   0.00038004   +38.55%
gpu_(10,10,1000,1000)_(-1,-1)   0.0032645    0.00010896   +96.66%
gpu_(10,10,1000,1000)_(-1,0)    0.007543     0.0038971    +48.34%
gpu_(10,10,1000,1000)_(0,-1)    0.006058     0.0039405    +34.95%
gpu_(10,10,1000,1000)_(2,2)     0.005198     0.003448     +33.67%
gpu_(10,10,1024,1024)_(-1,-1)   0.0034155    9.1434e-05   +97.32%
gpu_(10,10,1024,1024)_(-1,0)    0.007099     0.004158     +41.43%
gpu_(10,10,1024,1024)_(0,-1)    0.006843     0.003849     +43.75%
gpu_(10,10,1024,1024)_(2,2)     0.005506     0.0031376    +43.02%
gpu_(10,10,2048,2048)_(-1,-1)   0.013119     0.00010097   +99.23%
gpu_(10,10,2048,2048)_(-1,0)    0.028533     0.015175     +46.81%
gpu_(10,10,2048,2048)_(0,-1)    0.028458     0.014926     +47.55%
gpu_(10,10,2048,2048)_(2,2)     0.022175     0.011797     +46.80%

PiperOrigin-RevId: 168849471
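For context, every kernel touched by this change implements the same banding
rule: entry (row, col) of each matrix is kept iff (num_lower < 0 or
row - col <= num_lower) and (num_upper < 0 or col - row <= num_upper); a
negative limit keeps the corresponding triangle in full. A minimal NumPy
sketch of that rule, with illustrative names that are not part of the patch:

import numpy as np

def band_part(x, num_lower, num_upper):
    # Keep x[..., row, col] iff (num_lower < 0 or row - col <= num_lower)
    # and (num_upper < 0 or col - row <= num_upper), as in the kernels above.
    rows, cols = np.indices(x.shape[-2:])
    keep = np.ones(x.shape[-2:], dtype=bool)
    if num_lower >= 0:
        keep &= (rows - cols) <= num_lower
    if num_upper >= 0:
        keep &= (cols - rows) <= num_upper
    return np.where(keep, x, np.zeros_like(x))

# The (2, 2) benchmark case keeps the main diagonal plus two diagonals on
# either side; (-1, -1) keeps everything and degenerates to a pure copy.
print(band_part(np.ones((5, 5)), 2, 2))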
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/core/kernels/BUILD                                |   4
-rw-r--r--  tensorflow/core/kernels/cholesky_op.cc                       |  23
-rw-r--r--  tensorflow/core/kernels/matrix_band_part_op.cc               | 167
-rw-r--r--  tensorflow/core/kernels/matrix_band_part_op.h                |  53
-rw-r--r--  tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc        |  80
-rw-r--r--  tensorflow/python/kernel_tests/matrix_band_part_op_test.py   | 120
6 files changed, 318 insertions(+), 129 deletions(-)
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 9a1e96a131..8bcbe0dd41 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -665,7 +665,9 @@ tf_kernel_library(
tf_kernel_library(
name = "matrix_band_part_op",
prefix = "matrix_band_part_op",
- deps = ARRAY_DEPS,
+ deps = if_cuda([
+ ":cuda_solvers",
+ ]) + ARRAY_DEPS,
)
tf_kernel_library(
diff --git a/tensorflow/core/kernels/cholesky_op.cc b/tensorflow/core/kernels/cholesky_op.cc
index 755ce7c43b..6668b0d654 100644
--- a/tensorflow/core/kernels/cholesky_op.cc
+++ b/tensorflow/core/kernels/cholesky_op.cc
@@ -76,18 +76,19 @@ class CholeskyOp : public LinearAlgebraOp<Scalar> {
typedef Eigen::GpuDevice GPUDevice;
namespace functor {
-#define DECLARE_GPU_SPEC(T) \
- template <> \
- void MatrixBandPart<GPUDevice, T>::Compute( \
- const GPUDevice& d, Eigen::DenseIndex num_lower, \
- Eigen::DenseIndex num_upper, typename TTypes<T, 3>::ConstTensor input, \
- typename TTypes<T, 3>::Tensor output); \
- extern template struct MatrixBandPart<GPUDevice, T>;
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ struct MatrixBandPartFunctor<GPUDevice, T> { \
+ void operator()(OpKernelContext* context, const GPUDevice& device, \
+                  int num_lower_diags, int num_upper_diags, bool transpose, \
+ typename TTypes<T, 3>::ConstTensor input, \
+ typename TTypes<T, 3>::Tensor output); \
+ }; \
+ extern template struct MatrixBandPartFunctor<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_complex64(DECLARE_GPU_SPEC);
TF_CALL_complex128(DECLARE_GPU_SPEC);
-
} // namespace functor
template <class Scalar>
@@ -131,9 +132,9 @@ class CholeskyOpGpu : public AsyncOpKernel {
// before we launch each of the Cholesky factorization kernels in parallel.
auto input_reshaped = input.template flat_inner_dims<Scalar, 3>();
auto output_reshaped = output->template flat_inner_dims<Scalar, 3>();
- functor::MatrixBandPart<GPUDevice, Scalar>::Compute(
- context->eigen_device<GPUDevice>(), n, 0, input_reshaped,
- output_reshaped);
+ functor::MatrixBandPartFunctor<GPUDevice, Scalar> fn;
+ fn(context, context->eigen_device<GPUDevice>(), n, 0, false /* transpose */,
+ input_reshaped, output_reshaped);
// Launch a Cholesky kernel for each matrix in the batch.
const int64 batch_size = input_reshaped.dimension(0);
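A note on the call above: passing num_lower = n and num_upper = 0 zeroes
everything above the main diagonal, i.e. it extracts the lower triangle that
the Cholesky factorization expects. A quick NumPy check of that equivalence
(illustrative only, not part of the patch):

import numpy as np

a = np.arange(16, dtype=np.float64).reshape(4, 4)
n = a.shape[-1]
rows, cols = np.indices(a.shape)
# band(n, 0): up to n sub-diagonals, no super-diagonals.
banded = np.where((rows - cols <= n) & (cols - rows <= 0), a, 0.0)
assert np.array_equal(banded, np.tril(a))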
diff --git a/tensorflow/core/kernels/matrix_band_part_op.cc b/tensorflow/core/kernels/matrix_band_part_op.cc
index 894b0113c2..8b8accc0b3 100644
--- a/tensorflow/core/kernels/matrix_band_part_op.cc
+++ b/tensorflow/core/kernels/matrix_band_part_op.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/kernels/matrix_band_part_op.h"
+#include <algorithm>
#include <memory>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -32,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -48,31 +50,50 @@ class MatrixBandPartOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
+ const TensorShape& input_shape = input.shape();
+ // Preliminary validation of sizes.
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
+ errors::InvalidArgument(
+ "input must be at least 2-dim, received shape: ",
+ input.shape().DebugString()));
+ auto input_reshaped = input.flat_inner_dims<T, 3>();
+
const Tensor& num_lower_in = context->input(1);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in.shape()),
errors::InvalidArgument("num_lower must be scalar, got shape ",
num_lower_in.shape().DebugString()));
const int64 num_lower = num_lower_in.scalar<int64>()();
+ OP_REQUIRES(
+ context, num_lower <= input_reshaped.dimension(1),
+ errors::InvalidArgument(
+          "num_lower must be negative or less than or equal to the number of rows (",
+ input_reshaped.dimension(1), ") got: ", num_lower));
const Tensor& num_upper_in = context->input(2);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in.shape()),
errors::InvalidArgument("num_upper must be scalar, got shape ",
num_upper_in.shape().DebugString()));
const int64 num_upper = num_upper_in.scalar<int64>()();
+ OP_REQUIRES(context, num_upper <= input_reshaped.dimension(2),
+                errors::InvalidArgument("num_upper must be negative or less "
+                                        "than or equal to the number of columns (",
+ input_reshaped.dimension(2),
+ ") got: ", num_upper));
+
+ if ((num_lower < 0 || num_lower == input_reshaped.dimension(1)) &&
+ (num_upper < 0 || num_upper == input_reshaped.dimension(2))) {
+ // This is a no-op.
+ context->set_output(0, input);
+ return;
+ }
- const TensorShape& input_shape = input.shape();
- // Preliminary validation of sizes.
- OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
- errors::InvalidArgument(
- "input must be at least 2-dim, received shape: ",
- input.shape().DebugString()));
- auto input_reshaped = input.flat_inner_dims<T, 3>();
Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &output));
+ OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
+ {0}, 0, input_shape, &output));
auto output_reshaped = output->flat_inner_dims<T, 3>();
- functor::MatrixBandPart<Device, T>::Compute(
- context->eigen_device<Device>(), num_lower, num_upper, input_reshaped,
- output_reshaped);
+ functor::MatrixBandPartFunctor<Device, T> fn;
+ fn(context, context->eigen_device<Device>(), num_lower, num_upper,
+ false /* transpose */, input_reshaped, output_reshaped);
}
private:
@@ -98,54 +119,118 @@ TF_CALL_NUMBER_TYPES(REGISTER_BATCH_MATRIX_BAND_PART);
// Implementation of the functor specialization for CPU.
namespace functor {
-template <typename T>
-struct MatrixBandPart<CPUDevice, T> {
- static void Compute(const CPUDevice& d, int64 num_lower, int64 num_upper,
- typename TTypes<T, 3>::ConstTensor input,
- typename TTypes<T, 3>::Tensor output) {
- if ((num_lower < 0 || num_lower >= input.dimension(1)) &&
- (num_upper < 0 || num_upper >= input.dimension(2))) {
- output.device(d) = input;
+
+// CPU implementation of MatrixBandPartFunctor.
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename Scalar>
+struct MatrixBandPartFunctor<CPUDevice, Scalar> {
+ void operator()(OpKernelContext* context, const CPUDevice& device,
+ int num_lower_diags, int num_upper_diags, bool transpose,
+ typename TTypes<Scalar, 3>::ConstTensor input,
+ typename TTypes<Scalar, 3>::Tensor output) {
+ const int64 b = input.dimension(0);
+ const int64 m = input.dimension(1);
+ const int64 n = input.dimension(2);
+ auto thread_pool =
+ context->device()->tensorflow_cpu_worker_threads()->workers;
+ const int64 total_rows = b * m;
+ const int64 row_cost = 10 * n;
+ const bool in_place = input.data() == output.data();
+ CHECK(!(transpose && in_place));
+ if (!transpose) {
+ auto compute_shard = [=, &input, &output](int64 begin, int64 end) {
+ if (!in_place) {
+ std::fill(output.data() + begin * n, output.data() + end * n,
+ Scalar());
+ }
+ const int64 batch_begin = begin / m;
+ const int64 batch_end = (end + m - 1) / m;
+ for (int64 batch = batch_begin; batch < batch_end; ++batch) {
+ const int64 row_begin = begin > batch * m ? begin % m : 0;
+ const int64 row_end = end < (batch + 1) * m ? end % m : m;
+ for (int64 row = row_begin; row < row_end; ++row) {
+ const int64 band_start =
+ num_lower_diags < 0
+ ? 0
+ : std::min(n, std::max(0ll, row - num_lower_diags));
+ const int64 band_end = num_upper_diags < 0
+ ? n
+ : std::min(static_cast<int64>(n),
+ row + num_upper_diags + 1);
+ if (in_place) {
+ if (band_start > 0) {
+ std::fill(&output(batch, row, 0),
+ &output(batch, row, band_start), Scalar());
+ }
+ if (band_end < n) {
+ std::fill(&output(batch, row, band_end), &output(batch, row, n),
+ Scalar());
+ }
+ } else {
+ if (band_start < band_end) {
+ const Eigen::DSizes<Eigen::DenseIndex, 3> indices(batch, row,
+ band_start);
+ const Eigen::DSizes<Eigen::DenseIndex, 3> sizes(
+ 1, 1, band_end - band_start);
+ output.slice(indices, sizes) = input.slice(indices, sizes);
+ }
+ }
+ }
+ }
+ };
+ thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard));
} else {
- output.device(d) = output.constant(T());
- for (int64 r = 0; r < output.dimension(0); ++r) {
- for (int64 i = 0; i < output.dimension(1); ++i) {
- const int64 band_start =
- num_lower < 0 ? 0 : std::max(0ll, i - num_lower);
- const int64 band_end =
- num_upper < 0 ? output.dimension(2)
- : std::min(static_cast<int64>(output.dimension(2)),
- i + num_upper + 1);
- if (band_start < band_end) {
- const Eigen::DSizes<Eigen::DenseIndex, 3> indices(r, i, band_start);
- const Eigen::DSizes<Eigen::DenseIndex, 3> sizes(
- 1, 1, band_end - band_start);
- output.slice(indices, sizes) = input.slice(indices, sizes);
+ output.device(device) = output.constant(Scalar());
+ auto compute_shard = [=, &input, &output](int64 begin, int64 end) {
+ const int64 batch_begin = begin / m;
+ const int64 batch_end = (end + m - 1) / m;
+ for (int64 batch = batch_begin; batch < batch_end; ++batch) {
+ const int64 row_begin = begin > batch * m ? begin % m : 0;
+ const int64 row_end = end < (batch + 1) * m ? end % m : m;
+ for (int64 row = row_begin; row < row_end; ++row) {
+ const int64 band_start =
+ num_lower_diags < 0 ? 0 : std::max(0ll, row - num_lower_diags);
+ const int64 band_end = num_upper_diags < 0
+ ? n
+ : std::min(static_cast<int64>(n),
+ row + num_upper_diags + 1);
+ for (int64 col = band_start; col < band_end; ++col) {
+ output(batch, col, row) = input(batch, row, col);
+ }
}
}
- }
+ };
+ thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard));
}
}
};
+#define DEFINE_CPU_SPEC(T) template struct MatrixBandPartFunctor<CPUDevice, T>;
+TF_CALL_POD_TYPES(DEFINE_CPU_SPEC);
+#undef DEFINE_CPU_SPEC
+
} // namespace functor
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
-#define DECLARE_GPU_SPEC(T) \
- template <> \
- void MatrixBandPart<GPUDevice, T>::Compute( \
- const GPUDevice& d, Eigen::DenseIndex num_lower, \
- Eigen::DenseIndex num_upper, typename TTypes<T, 3>::ConstTensor input, \
- typename TTypes<T, 3>::Tensor output); \
- extern template struct MatrixBandPart<GPUDevice, T>;
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ struct MatrixBandPartFunctor<GPUDevice, T> { \
+ void operator()(OpKernelContext* context, const GPUDevice& device, \
+                    int num_lower_diags, int num_upper_diags, bool transpose, \
+ typename TTypes<T, 3>::ConstTensor input, \
+ typename TTypes<T, 3>::Tensor output); \
+ }; \
+ extern template struct MatrixBandPartFunctor<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_bool(DECLARE_GPU_SPEC);
TF_CALL_complex64(DECLARE_GPU_SPEC);
TF_CALL_complex128(DECLARE_GPU_SPEC);
+#undef DECLARE_GPU_SPEC
} // namespace functor
// Registration of the GPU implementations.
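The subtlest part of the new CPU path is the shard decomposition: ParallelFor
hands compute_shard a half-open range [begin, end) over the b * m flattened
(batch, row) pairs, and the lambda recovers per-batch row sub-ranges from it.
A small Python sketch of that arithmetic, using a hypothetical helper name:

def shard_rows(begin, end, m):
    # Mirrors the begin/end arithmetic in the compute_shard lambdas above:
    # decompose a flat [begin, end) range over b * m rows into
    # (batch, row_begin, row_end) triples.
    batch_begin = begin // m
    batch_end = (end + m - 1) // m
    for batch in range(batch_begin, batch_end):
        row_begin = begin % m if begin > batch * m else 0
        row_end = end % m if end < (batch + 1) * m else m
        yield batch, row_begin, row_end

# A shard covering flattened rows [3, 7) of two 4-row matrices touches
# row 3 of batch 0 and rows 0..2 of batch 1.
assert list(shard_rows(3, 7, 4)) == [(0, 3, 4), (1, 0, 3)]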
diff --git a/tensorflow/core/kernels/matrix_band_part_op.h b/tensorflow/core/kernels/matrix_band_part_op.h
index b601255b25..43b6724dae 100644
--- a/tensorflow/core/kernels/matrix_band_part_op.h
+++ b/tensorflow/core/kernels/matrix_band_part_op.h
@@ -16,61 +16,22 @@ limitations under the License.
#ifndef TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
#define TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
-// Generator definition for MatrixBandPartOp, must be compilable by nvcc.
-
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
-
-namespace generator {
-
-template <typename T>
-class MatrixBandPartGenerator {
- public:
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE MatrixBandPartGenerator(
- Eigen::DenseIndex num_lower, Eigen::DenseIndex num_upper,
- typename TTypes<T, 3>::ConstTensor input)
- : num_lower_(num_lower), num_upper_(num_upper), input_(input) {}
-
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
- operator()(const Eigen::array<Eigen::DenseIndex, 3>& coords) const {
- return (((num_lower_ < 0 || coords[1] - coords[2] <= num_lower_) &&
- (num_upper_ < 0 || coords[2] - coords[1] <= num_upper_))
- ? input_(coords)
- : T());
- }
-
- private:
- const Eigen::DenseIndex num_lower_;
- const Eigen::DenseIndex num_upper_;
- typename TTypes<T, 3>::ConstTensor input_;
-};
-
-} // namespace generator
-
namespace functor {
-template <typename Device, typename T>
-struct MatrixBandPart {
- EIGEN_ALWAYS_INLINE static void Compute(
- const Device& d, Eigen::DenseIndex num_lower, Eigen::DenseIndex num_upper,
- typename TTypes<T, 3>::ConstTensor input,
- typename TTypes<T, 3>::Tensor output) {
- if ((num_lower < 0 || num_lower >= input.dimension(1)) &&
- (num_upper < 0 || num_upper >= input.dimension(2))) {
- output.device(d) = input;
- } else {
- generator::MatrixBandPartGenerator<T> generator(num_lower, num_upper,
- input);
- output.device(d) = output.generate(generator);
- }
- }
+template <typename Device, typename Scalar>
+struct MatrixBandPartFunctor {
+ void operator()(OpKernelContext* context, const Device& device,
+                  int num_lower_diags, int num_upper_diags, bool transpose,
+ typename TTypes<Scalar, 3>::ConstTensor input,
+ typename TTypes<Scalar, 3>::Tensor output);
};
} // namespace functor
-
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
diff --git a/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc b/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc
index ccc10ebada..afebdacdca 100644
--- a/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc
@@ -17,22 +17,92 @@ limitations under the License.
#define EIGEN_USE_GPU
+#include <complex>
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/matrix_band_part_op.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
namespace tensorflow {
-
+namespace functor {
typedef Eigen::GpuDevice GPUDevice;
-#define DEFINE_GPU_SPEC(T) \
- template class generator::MatrixBandPartGenerator<T>; \
- template struct functor::MatrixBandPart<GPUDevice, T>;
+template <bool transpose, typename Scalar>
+__global__ void MatrixBandPartKernel(const int num_threads,
+ const int batch_size, const int m,
+ const int n, const int num_lower_diags,
+ const int num_upper_diags,
+ const Scalar* input_ptr,
+ Scalar* output_ptr) {
+ if (!transpose) {
+ CUDA_1D_KERNEL_LOOP(index, num_threads) {
+ const int col = index % n;
+ const int row = (index / n) % m;
+ const int band_start = (num_lower_diags < 0 ? 0 : row - num_lower_diags);
+ const int band_end =
+ (num_upper_diags < 0 ? n : row + num_upper_diags + 1);
+ if (col < band_start || col >= band_end) {
+ output_ptr[index] = Scalar();
+ } else {
+ output_ptr[index] = input_ptr[index];
+ }
+ }
+ } else {
+ const int matrix_size = m * n;
+ CUDA_1D_KERNEL_LOOP(index, num_threads) {
+ const int col = index % n;
+ const int row = (index / n) % m;
+ const int batch = index / matrix_size;
+ const int transpose_index = batch * matrix_size + n * col + row;
+ const int band_start = (num_lower_diags < 0 ? 0 : row - num_lower_diags);
+ const int band_end =
+ (num_upper_diags < 0 ? n : row + num_upper_diags + 1);
+ if (col < band_start || col >= band_end) {
+ output_ptr[transpose_index] = Scalar();
+ } else {
+ output_ptr[transpose_index] = input_ptr[index];
+ }
+ }
+ }
+}
+
+template <typename Scalar>
+struct MatrixBandPartFunctor<GPUDevice, Scalar> {
+ void operator()(OpKernelContext* context, const GPUDevice& device,
+ int num_lower_diags, int num_upper_diags, bool transpose,
+ typename TTypes<Scalar, 3>::ConstTensor input,
+ typename TTypes<Scalar, 3>::Tensor output) {
+ using CudaType = typename CUDAComplexT<Scalar>::type;
+ const int batch_size = input.dimension(0);
+ const int m = input.dimension(1);
+ const int n = input.dimension(2);
+ const CudaType* input_ptr = reinterpret_cast<const CudaType*>(input.data());
+ CudaType* output_ptr = reinterpret_cast<CudaType*>(output.data());
+ CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * m * n, device);
+ if (transpose) {
+ MatrixBandPartKernel<true>
+ <<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
+ config.virtual_thread_count, batch_size, m, n, num_lower_diags,
+ num_upper_diags, input_ptr, output_ptr);
+ } else {
+ MatrixBandPartKernel<false>
+ <<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
+ config.virtual_thread_count, batch_size, m, n, num_lower_diags,
+ num_upper_diags, input_ptr, output_ptr);
+ }
+ }
+};
+
+#define DEFINE_GPU_SPEC(T) template struct MatrixBandPartFunctor<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
TF_CALL_bool(DEFINE_GPU_SPEC);
TF_CALL_complex64(DEFINE_GPU_SPEC);
TF_CALL_complex128(DEFINE_GPU_SPEC);
-} // end namespace tensorflow
+#undef DEFINE_GPU_SPEC
+} // namespace functor
+} // namespace tensorflow
#endif // GOOGLE_CUDA
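For reference, the CUDA kernel assigns one virtual thread per element of the
flattened batch * m * n index space and applies the band test per element. A
Python sketch of the per-thread index math on the transpose=false path
(illustrative only, not part of the patch):

def element_for_thread(index, m, n, num_lower, num_upper):
    # Mirrors the index arithmetic in MatrixBandPartKernel (transpose=false).
    col = index % n
    row = (index // n) % m
    batch = index // (m * n)
    band_start = 0 if num_lower < 0 else row - num_lower
    band_end = n if num_upper < 0 else row + num_upper + 1
    in_band = band_start <= col < band_end
    return batch, row, col, in_band

# Thread 7 over a single 3x4 matrix lands on (row=1, col=3); with a (1, 1)
# band that element lies outside the band, so the kernel writes Scalar().
assert element_for_thread(7, 3, 4, 1, 1) == (0, 1, 3, False)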
diff --git a/tensorflow/python/kernel_tests/matrix_band_part_op_test.py b/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
index e641d5511f..317b8dc05b 100644
--- a/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
@@ -19,13 +19,24 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradient_checker
-from tensorflow.python.platform import test
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test as test_lib
-class MatrixBandPartTest(test.TestCase):
+def _AddTest(test, op_name, testcase_name, fn):
+ test_name = "_".join(["test", op_name, testcase_name])
+ if hasattr(test, test_name):
+ raise RuntimeError("Test %s defined more than once" % test_name)
+ setattr(test, test_name, fn)
+
+
+class MatrixBandPartTest(test_lib.TestCase):
pass # Filled in below
@@ -34,23 +45,23 @@ def _GetMatrixBandPartTest(dtype_, batch_shape_, shape_):
def Test(self):
mat = np.ones(shape_).astype(dtype_)
batch_mat = np.tile(mat, batch_shape_ + (1, 1))
- with self.test_session(use_gpu=True):
- for lower in -1, 0, 1, shape_[-2] - 1:
- for upper in -1, 0, 1, shape_[-1] - 1:
- band_np = mat
- if lower >= 0:
- band_np = np.triu(band_np, -lower)
- if upper >= 0:
- band_np = np.tril(band_np, upper)
- if batch_shape_ is not ():
- band_np = np.tile(band_np, batch_shape + (1, 1))
+ for lower in -1, 0, 1, shape_[-2] - 1:
+ for upper in -1, 0, 1, shape_[-1] - 1:
+ band_np = mat
+ if lower >= 0:
+ band_np = np.triu(band_np, -lower)
+ if upper >= 0:
+ band_np = np.tril(band_np, upper)
+      if batch_shape_ != ():
+ band_np = np.tile(band_np, batch_shape_ + (1, 1))
+ with self.test_session(use_gpu=False):
band = array_ops.matrix_band_part(batch_mat, lower, upper)
self.assertAllEqual(band_np, band.eval())
return Test
-class MatrixBandPartGradTest(test.TestCase):
+class MatrixBandPartGradTest(test_lib.TestCase):
pass # Filled in below
@@ -59,7 +70,7 @@ def _GetMatrixBandPartGradTest(dtype_, batch_shape_, shape_):
def Test(self):
shape = batch_shape_ + shape_
x = constant_op.constant(np.random.rand(*shape), dtype=dtype_)
- with self.test_session(use_gpu=True):
+ with self.test_session(use_gpu=False):
for lower in -1, 0, 1, shape_[-2] - 1:
for upper in -1, 0, 1, shape_[-1] - 1:
y = array_ops.matrix_band_part(x, lower, upper)
@@ -70,18 +81,77 @@ def _GetMatrixBandPartGradTest(dtype_, batch_shape_, shape_):
return Test
-if __name__ == '__main__':
- for dtype in (
- np.int32, np.int64, np.float32, np.float64, np.complex64, np.complex128):
+class MatrixBandPartBenchmark(test_lib.Benchmark):
+
+ shapes = [
+ (10, 16, 16),
+ (10, 101, 101),
+ (10, 256, 256),
+ (10, 1000, 1000),
+ (10, 1024, 1024),
+ (10, 2048, 2048),
+ (10, 10, 4, 4),
+ (10, 10, 10, 10),
+ (10, 10, 16, 16),
+ (10, 10, 101, 101),
+ (10, 10, 256, 256),
+ (10, 10, 1000, 1000),
+ (10, 10, 1024, 1024),
+ (10, 10, 2048, 2048),
+ ]
+
+ def benchmarkMatrixBandPartOp(self):
+ for shape_ in self.shapes:
+ for limits in (-1, -1), (-1, 0), (0, -1), (2, 2):
+ with ops.Graph().as_default(), \
+ session.Session() as sess, \
+ ops.device("/cpu:0"):
+ matrix = variables.Variable(array_ops.ones(shape_))
+ band = array_ops.matrix_band_part(matrix, limits[0], limits[1])
+ variables.global_variables_initializer().run()
+ self.run_op_benchmark(
+ sess,
+ control_flow_ops.group(band),
+ min_iters=10,
+ name="matrix_band_part_cpu_{shape}_{limits}".format(
+ shape=shape_, limits=limits))
+
+ if test_lib.is_gpu_available(True):
+ with ops.Graph().as_default(), \
+ session.Session() as sess, \
+ ops.device("/gpu:0"):
+ matrix = variables.Variable(array_ops.ones(shape_))
+ band = array_ops.matrix_band_part(matrix, limits[0], limits[1])
+ variables.global_variables_initializer().run()
+ self.run_op_benchmark(
+ sess,
+ control_flow_ops.group(band),
+ min_iters=10,
+ name="matrix_band_part_gpu_{shape}_{limits}".format(
+ shape=shape_, limits=limits))
+
+
+if __name__ == "__main__":
+ dtypes = (np.bool, np.int32, np.int64, np.float32, np.float64, np.complex64,
+ np.complex128)
+ for dtype in dtypes:
for batch_shape in ((), (2,), (1, 3, 2)):
for rows in 1, 2, 7:
for cols in 1, 2, 7:
shape = (rows, cols)
- name = '%s_%s' % (dtype.__name__, '_'.join(map(str, shape)))
- setattr(MatrixBandPartTest, 'testMatrixBandPart_' + name,
- _GetMatrixBandPartTest(dtype, batch_shape, shape))
- if dtype == np.float32 or dtype == np.float64:
- setattr(MatrixBandPartGradTest, 'testMatrixBandPartGrad_' + name,
- _GetMatrixBandPartGradTest(dtype, batch_shape, shape))
-
- test.main()
+ name = "%s_%s" % (dtype.__name__,
+ "_".join(map(str, batch_shape + shape)))
+ _AddTest(MatrixBandPartTest, "MatrixBandPart", name,
+ _GetMatrixBandPartTest(dtype, batch_shape, shape))
+
+ for dtype in (np.float32, np.float64):
+ for batch_shape in ((), (2,)):
+ for rows in 1, 2, 7:
+ for cols in 1, 2, 7:
+ shape = (rows, cols)
+ name = "%s_%s" % (dtype.__name__,
+ "_".join(map(str, batch_shape + shape)))
+ _AddTest(MatrixBandPartGradTest, "MatrixBandPartGrad", name,
+ _GetMatrixBandPartGradTest(dtype, batch_shape, shape))
+
+ test_lib.main()