-rw-r--r--  tensorflow/core/kernels/cholesky_op.cc                 | 15
-rw-r--r--  tensorflow/core/kernels/matrix_band_part_op.cc         |  5
-rw-r--r--  tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc  |  5
-rw-r--r--  tensorflow/core/kernels/matrix_inverse_op.cc           | 13
-rw-r--r--  tensorflow/core/kernels/matrix_set_diag_op.cc          | 71
-rw-r--r--  tensorflow/core/kernels/matrix_set_diag_op.h           | 68
-rw-r--r--  tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc   | 70
7 files changed, 121 insertions(+), 126 deletions(-)
diff --git a/tensorflow/core/kernels/cholesky_op.cc b/tensorflow/core/kernels/cholesky_op.cc
index 8b401a565b..bcd42dc8d7 100644
--- a/tensorflow/core/kernels/cholesky_op.cc
+++ b/tensorflow/core/kernels/cholesky_op.cc
@@ -112,6 +112,14 @@ class CholeskyOpGpu : public AsyncOpKernel {
input.dim_size(ndims - 2), " != ", n),
done);
+ if (input.NumElements() == 0) {
+ // If X is an empty matrix (0 rows, 0 col), X * X' == X.
+ // Therefore, we return X.
+ context->set_output(0, input);
+ done();
+ return;
+ }
+
// Allocate output.
// TODO(rmlarsen): Convert to std::make_unique when available.
std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
@@ -121,13 +129,6 @@ class CholeskyOpGpu : public AsyncOpKernel {
{0}, 0, input.shape(), &output),
done);
- if (n == 0) {
- // If X is an empty matrix (0 rows, 0 col), X * X' == X.
- // Therefore, we return X.
- done();
- return;
- }
-
// Copy the lower triangular part of the input matrices to the output and
// set the strictly upper triangular part to zero. We use a pre-existing
// kernel MatrixBandPart to do this for all matrices in the batch at once,
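
The hunk above hoists the empty-matrix fast path ahead of the CudaSolver construction and the output allocation: for a batch of 0 x 0 matrices the op now forwards the input tensor as the output and signals completion without allocating anything. A minimal sketch of that early-return pattern, using hypothetical stand-in types (Tensor, Context) rather than the real TensorFlow API:

#include <cstdint>
#include <iostream>

// Hypothetical stand-ins for the TensorFlow types; illustrative only.
struct Tensor {
  int64_t num_elements = 0;
  int64_t NumElements() const { return num_elements; }
};

struct Context {
  Tensor output;
  void set_output(int index, const Tensor& t) { output = t; }
};

// Returns true if the op finished early: an empty matrix is its own
// Cholesky factor, so the input is forwarded and no solver or output
// buffer is ever created.
bool MaybeForwardEmptyInput(Context* context, const Tensor& input) {
  if (input.NumElements() == 0) {
    context->set_output(0, input);
    return true;
  }
  return false;
}

int main() {
  Context context;
  Tensor empty;  // a batch of 0 x 0 matrices
  std::cout << MaybeForwardEmptyInput(&context, empty) << "\n";  // prints 1
}

The same reordering is applied to matrix_inverse_op.cc below, for the same reason: checking after allocation wasted an output buffer on a case that returns immediately.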
diff --git a/tensorflow/core/kernels/matrix_band_part_op.cc b/tensorflow/core/kernels/matrix_band_part_op.cc
index e5f9086dba..d7fff4bb0c 100644
--- a/tensorflow/core/kernels/matrix_band_part_op.cc
+++ b/tensorflow/core/kernels/matrix_band_part_op.cc
@@ -80,8 +80,9 @@ class MatrixBandPartOp : public OpKernel {
input_reshaped.dimension(2),
") got: ", num_upper));
- if ((num_lower < 0 || num_lower == input_reshaped.dimension(1)) &&
- (num_upper < 0 || num_upper == input_reshaped.dimension(2))) {
+ if (input.NumElements() == 0 ||
+ ((num_lower < 0 || num_lower == input_reshaped.dimension(1)) &&
+ (num_upper < 0 || num_upper == input_reshaped.dimension(2)))) {
// This is a no-op.
context->set_output(0, input);
return;
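
The widened condition above adds empty inputs to the existing full-band case, so both now take the forwarding fast path. Extracted into self-contained form (a sketch; the parameter names are illustrative, not TensorFlow API):

#include <cstdint>

// True when MatrixBandPart can forward its input unchanged: the input
// is empty, or the requested band covers the whole matrix (a negative
// bound means "unbounded" in this op).
bool BandPartIsNoOp(int64_t num_elements, int64_t num_rows, int64_t num_cols,
                    int64_t num_lower, int64_t num_upper) {
  return num_elements == 0 ||
         ((num_lower < 0 || num_lower == num_rows) &&
          (num_upper < 0 || num_upper == num_cols));
}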
diff --git a/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc b/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc
index 41b2f5c0ef..628d22b458 100644
--- a/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc
@@ -54,17 +54,14 @@ struct MatrixBandPartFunctor<GPUDevice, Scalar> {
int num_lower_diags, int num_upper_diags,
typename TTypes<Scalar, 3>::ConstTensor input,
typename TTypes<Scalar, 3>::Tensor output) {
- using CudaType = typename CUDAComplexT<Scalar>::type;
const int batch_size = input.dimension(0);
const int m = input.dimension(1);
const int n = input.dimension(2);
- const CudaType* input_ptr = reinterpret_cast<const CudaType*>(input.data());
- CudaType* output_ptr = reinterpret_cast<CudaType*>(output.data());
CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * m * n, device);
MatrixBandPartKernel<<<config.block_count, config.thread_per_block, 0,
device.stream()>>>(
config.virtual_thread_count, batch_size, m, n, num_lower_diags,
- num_upper_diags, input_ptr, output_ptr);
+ num_upper_diags, input.data(), output.data());
}
};
diff --git a/tensorflow/core/kernels/matrix_inverse_op.cc b/tensorflow/core/kernels/matrix_inverse_op.cc
index a152b5cbee..832e508bb7 100644
--- a/tensorflow/core/kernels/matrix_inverse_op.cc
+++ b/tensorflow/core/kernels/matrix_inverse_op.cc
@@ -109,6 +109,13 @@ class MatrixInverseOpGpu : public AsyncOpKernel {
input.dim_size(ndims - 2), " != ", n),
done);
+ // By definition, an empty matrix's inverse is an empty matrix.
+ if (input.NumElements() == 0) {
+ context->set_output(0, input);
+ done();
+ return;
+ }
+
// Allocate output.
Tensor* output;
OP_REQUIRES_OK_ASYNC(context,
@@ -116,12 +123,6 @@ class MatrixInverseOpGpu : public AsyncOpKernel {
{0}, 0, input.shape(), &output),
done);
- // By definition, an empty matrix's inverse is an empty matrix.
- if (input.NumElements() == 0) {
- done();
- return;
- }
-
// TODO(rmlarsen): Convert to std::make_unique when available.
std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
diff --git a/tensorflow/core/kernels/matrix_set_diag_op.cc b/tensorflow/core/kernels/matrix_set_diag_op.cc
index 9573c4f8d1..9dd665392b 100644
--- a/tensorflow/core/kernels/matrix_set_diag_op.cc
+++ b/tensorflow/core/kernels/matrix_set_diag_op.cc
@@ -23,8 +23,6 @@ limitations under the License.
#include "tensorflow/core/kernels/matrix_set_diag_op.h"
-#include <memory>
-#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -32,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -73,22 +72,21 @@ class MatrixSetDiagOp : public OpKernel {
input_shape.DebugString(), " and diagonal shape: ",
diag_shape.DebugString()));
+ if (input.NumElements() == 0) {
+ // This is a no-op.
+ context->set_output(0, input);
+ return;
+ }
+
auto input_reshaped = input.flat_inner_dims<T, 3>();
auto diag_reshaped = diag.flat_inner_dims<T, 2>();
-
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{0}, 0, input_shape, &output));
auto output_reshaped = output->flat_inner_dims<T, 3>();
- Tensor scratch_tensor;
- OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({}), &scratch_tensor));
- auto scratch = scratch_tensor.scalar<T>();
-
- functor::MatrixSetDiag<Device, T>::Compute(context->eigen_device<Device>(),
- input_reshaped, diag_reshaped,
- scratch, output_reshaped);
+ functor::MatrixSetDiag<Device, T>::Compute(
+ context, context->eigen_device<Device>(), input_reshaped, diag_reshaped,
+ output_reshaped);
}
private:
@@ -116,32 +114,25 @@ namespace functor {
// Implementation of the functor specialization for CPU.
template <typename T>
struct MatrixSetDiag<CPUDevice, T> {
- static void Compute(const CPUDevice& d,
+ static void Compute(OpKernelContext* context, const CPUDevice& device,
typename TTypes<T, 3>::ConstTensor input,
typename TTypes<T, 2>::ConstTensor diag,
- typename TTypes<T>::Scalar scratch,
typename TTypes<T, 3>::Tensor output) {
- output.device(d) = input;
- for (int64 r = 0; r < output.dimension(0); ++r) {
- for (int64 d = 0; d < diag.dimension(1); ++d) {
- output(r, d, d) = diag(r, d);
- }
+ if (input.data() != output.data()) {
+ output.device(device) = input;
}
- }
-};
-
-template <>
-struct MatrixSetDiag<CPUDevice, bool> {
- static void Compute(const CPUDevice& d, TTypes<bool, 3>::ConstTensor input,
- TTypes<bool, 2>::ConstTensor diag,
- TTypes<bool>::Scalar scratch,
- TTypes<bool, 3>::Tensor output) {
- output.device(d) = input;
- for (int64 r = 0; r < output.dimension(0); ++r) {
- for (int64 d = 0; d < diag.dimension(1); ++d) {
- output(r, d, d) = diag(r, d);
+ auto compute_shard = [&output, &diag](int64 begin, int64 end) {
+ for (int64 batch = begin; batch < end; ++batch) {
+ for (int64 col = 0; col < diag.dimension(1); ++col) {
+ output(batch, col, col) = diag(batch, col);
+ }
}
- }
+ };
+ auto thread_pool =
+ context->device()->tensorflow_cpu_worker_threads()->workers;
+ int64 cost_per_batch = 10 * output.dimension(1); // Heuristic.
+ thread_pool->ParallelFor(output.dimension(0), cost_per_batch,
+ std::move(compute_shard));
}
};
@@ -151,13 +142,13 @@ struct MatrixSetDiag<CPUDevice, bool> {
// Forward declarations of the functor specializations for GPU.
namespace functor {
-#define DECLARE_GPU_SPEC(T) \
- template <> \
- void MatrixSetDiag<GPUDevice, T>::Compute( \
- const GPUDevice& d, typename TTypes<T, 3>::ConstTensor input, \
- typename TTypes<T, 2>::ConstTensor diag, \
- typename TTypes<T>::Scalar scratch, \
- typename TTypes<T, 3>::Tensor output); \
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void MatrixSetDiag<GPUDevice, T>::Compute( \
+ OpKernelContext* context, const GPUDevice& d, \
+ typename TTypes<T, 3>::ConstTensor input, \
+ typename TTypes<T, 2>::ConstTensor diag, \
+ typename TTypes<T, 3>::Tensor output); \
extern template struct MatrixSetDiag<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
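
The CPU rewrite above replaces the Eigen generator-plus-scratch-scalar trick with a direct loop over diagonal entries, sharded across batches on the device thread pool, and it skips the initial copy when the output aliases the input. A self-contained sketch of the sharding shape and the cost heuristic, where ParallelForStandIn is a hypothetical serial stand-in for the thread pool's ParallelFor:

#include <algorithm>
#include <cstdint>
#include <functional>
#include <vector>

using Shard = std::function<void(int64_t, int64_t)>;

void ParallelForStandIn(int64_t total, int64_t /*cost_per_unit*/,
                        const Shard& shard) {
  shard(0, total);  // A real pool splits [0, total) across workers,
                    // sizing shards by the per-unit cost estimate.
}

// output: batch_size x m x n, row-major; diag: batch_size x min(m, n).
void SetDiagSharded(std::vector<float>& output, const std::vector<float>& diag,
                    int64_t batch_size, int64_t m, int64_t n) {
  const int64_t minsize = std::min(m, n);
  auto shard = [&](int64_t begin, int64_t end) {
    for (int64_t batch = begin; batch < end; ++batch) {
      for (int64_t col = 0; col < minsize; ++col) {
        output[batch * m * n + col * n + col] = diag[batch * minsize + col];
      }
    }
  };
  // Mirrors the heuristic in the diff: ~10 ops per diagonal element.
  const int64_t cost_per_batch = 10 * m;
  ParallelForStandIn(batch_size, cost_per_batch, shard);
}

Sharding at batch granularity keeps each worker's writes contiguous within a matrix, and the cost estimate lets the pool skip threading entirely for small batches.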
diff --git a/tensorflow/core/kernels/matrix_set_diag_op.h b/tensorflow/core/kernels/matrix_set_diag_op.h
index 63e5650bf0..aeb144559f 100644
--- a/tensorflow/core/kernels/matrix_set_diag_op.h
+++ b/tensorflow/core/kernels/matrix_set_diag_op.h
@@ -16,80 +16,22 @@ limitations under the License.
#ifndef TENSORFLOW_KERNELS_MATRIX_SET_DIAG_OP_H_
#define TENSORFLOW_KERNELS_MATRIX_SET_DIAG_OP_H_
-// Generator definition for MatrixSetDiagOp, must be compilable by nvcc.
-
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
-
-namespace generator {
-
-template <typename T>
-class OverwriteDiagGenerator {
- public:
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
- OverwriteDiagGenerator(typename TTypes<T, 2>::ConstTensor diag,
- typename TTypes<T, 3>::Tensor output)
- : diag_(diag), output_(output) {}
-
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
- operator()(const Eigen::array<Eigen::DenseIndex, 2>& coords) const {
- Eigen::array<Eigen::DenseIndex, 3> diag_from_coords(
- {coords[0], coords[1], coords[1]});
-
- // This is the side effect we care about.
- output_(diag_from_coords) = diag_(coords);
-
- return T(0);
- }
-
- private:
- typename TTypes<T, 2>::ConstTensor diag_;
- mutable typename TTypes<T, 3>::Tensor output_;
-};
-
-} // namespace generator
-
namespace functor {
template <typename Device, typename T>
struct MatrixSetDiag {
- EIGEN_ALWAYS_INLINE static void Compute(
- const Device& d, typename TTypes<T, 3>::ConstTensor input,
- typename TTypes<T, 2>::ConstTensor diag,
- typename TTypes<T>::Scalar scratch,
- typename TTypes<T, 3>::Tensor output) {
- output.device(d) = input;
- generator::OverwriteDiagGenerator<T> generator(diag, output);
- // Use sum() to force the generation to aggregate to the scalar
- // output scratch. This in turn forces each element of the
- // generator to execute. The side effect of the execution is to
- // update the diagonal components of output with diag.
- scratch.device(d) = diag.generate(generator).sum();
- }
-};
-
-template <typename Device>
-struct MatrixSetDiag<Device, bool> {
- EIGEN_ALWAYS_INLINE static void Compute(const Device& d,
- TTypes<bool, 3>::ConstTensor input,
- TTypes<bool, 2>::ConstTensor diag,
- TTypes<bool>::Scalar scratch,
- TTypes<bool, 3>::Tensor output) {
- output.device(d) = input;
- generator::OverwriteDiagGenerator<bool> generator(diag, output);
- // Use all() to force the generation to aggregate to the scalar
- // output scratch. This in turn forces each element of the
- // generator to execute. The side effect of the execution is to
- // update the diagonal components of output with diag.
- scratch.device(d) = diag.generate(generator).all();
- }
+ static void Compute(OpKernelContext* context, const Device& d,
+ typename TTypes<T, 3>::ConstTensor input,
+ typename TTypes<T, 2>::ConstTensor diag,
+ typename TTypes<T, 3>::Tensor output);
};
} // namespace functor
-
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_MATRIX_SET_DIAG_OP_H_
diff --git a/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc b/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc
index 8e41ce5860..35037b8e14 100644
--- a/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc
@@ -19,20 +19,82 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/matrix_set_diag_op.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
namespace tensorflow {
+namespace functor {
typedef Eigen::GpuDevice GPUDevice;
-#define DEFINE_GPU_SPEC(T) \
- template class generator::OverwriteDiagGenerator<T>; \
- template struct functor::MatrixSetDiag<GPUDevice, T>;
+template <typename Scalar>
+__global__ void MatrixSetDiagKernel(const int num_threads, const int m,
+ const int n, const int minsize,
+ const Scalar* diag_ptr,
+ Scalar* output_ptr) {
+ CUDA_1D_KERNEL_LOOP(index, num_threads) {
+ const int batch = index / minsize;
+ const int col = index - batch * minsize;
+ const int out_index = batch * m * n + (n + 1) * col;
+ output_ptr[out_index] = diag_ptr[index];
+ }
+}
+
+template <typename Scalar>
+__global__ void MatrixCopyInputAndSetDiagKernel(
+ const int num_threads, const int m, const int n, const int minsize,
+ const Scalar* input_ptr, const Scalar* diag_ptr, Scalar* output_ptr) {
+ CUDA_1D_KERNEL_LOOP(index, num_threads) {
+ const int global_row = index / n;
+ const int col = index - global_row * n;
+ const int batch = global_row / m;
+ const int row = global_row - batch * m;
+ if (col == row) {
+ // Because col = index % n, and row = (index / n) % m,
+ // we know that col==row => col < minsize, so the following is safe:
+ output_ptr[index] = diag_ptr[batch * minsize + col];
+ } else {
+ output_ptr[index] = input_ptr[index];
+ }
+ }
+}
+
+template <typename Scalar>
+struct MatrixSetDiag<GPUDevice, Scalar> {
+ static void Compute(OpKernelContext* context, const GPUDevice& device,
+ typename TTypes<Scalar, 3>::ConstTensor input,
+ typename TTypes<Scalar, 2>::ConstTensor diag,
+ typename TTypes<Scalar, 3>::Tensor output) {
+ const int batch_size = input.dimension(0);
+ const int m = input.dimension(1);
+ const int n = input.dimension(2);
+ const int minsize = std::min(m, n);
+ CHECK_EQ(diag.dimension(1), minsize);
+ if (batch_size == 0 || minsize == 0) return;
+ if (input.data() == output.data()) {
+ CudaLaunchConfig config =
+ GetCudaLaunchConfig(batch_size * minsize, device);
+ MatrixSetDiagKernel<Scalar>
+ <<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
+ config.virtual_thread_count, m, n, minsize, diag.data(),
+ output.data());
+ } else {
+ CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * m * n, device);
+ MatrixCopyInputAndSetDiagKernel<Scalar>
+ <<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
+ config.virtual_thread_count, m, n, minsize, input.data(),
+ diag.data(), output.data());
+ }
+ }
+};
+
+#define DEFINE_GPU_SPEC(T) template struct MatrixSetDiag<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
TF_CALL_bool(DEFINE_GPU_SPEC);
TF_CALL_complex64(DEFINE_GPU_SPEC);
TF_CALL_complex128(DEFINE_GPU_SPEC);
-} // end namespace tensorflow
+} // namespace functor
+} // namespace tensorflow
#endif // GOOGLE_CUDA
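
Both kernels above rest on flat-index arithmetic over a row-major batch x m x n layout: element (b, r, c) sits at flat offset b*m*n + r*n + c, so the diagonal entry (b, c, c) sits at b*m*n + (n + 1)*c, and col == row in the copy kernel implies col < min(m, n). A host-side check of those identities (plain C++, no CUDA required):

#include <algorithm>
#include <cassert>

int main() {
  const int batch_size = 3, m = 4, n = 5;
  const int minsize = std::min(m, n);
  // Flat offset of element (b, r, c) in a row-major batch x m x n tensor.
  auto flat = [&](int b, int r, int c) { return (b * m + r) * n + c; };

  // MatrixSetDiagKernel: one thread per diagonal slot (batch, col).
  for (int index = 0; index < batch_size * minsize; ++index) {
    const int batch = index / minsize;
    const int col = index - batch * minsize;
    assert(batch * m * n + (n + 1) * col == flat(batch, col, col));
  }

  // MatrixCopyInputAndSetDiagKernel: one thread per element.
  for (int index = 0; index < batch_size * m * n; ++index) {
    const int global_row = index / n;
    const int col = index - global_row * n;
    const int batch = global_row / m;
    const int row = global_row - batch * m;
    assert(index == flat(batch, row, col));
    if (col == row) assert(col < minsize);  // diag_ptr read stays in bounds
  }
  return 0;
}

The in-place kernel launches only batch_size * min(m, n) threads since everything off the diagonal is already in place, while the copy-and-set kernel touches all batch_size * m * n elements in a single pass instead of a full copy followed by a second diagonal pass.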