path: root/tensorflow/core/kernels/matrix_set_diag_op.cc
author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-10-11 23:59:11 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-10-12 00:02:52 -0700
commit    cec93f10dcf5d2e647a41bd6bf95357cf9d9169d (patch)
tree      5fc58a2e8cbd9becdda1c6cbfad1045c35dd9b74 /tensorflow/core/kernels/matrix_set_diag_op.cc
parent    d455bd6a851450657c702808d096f39583b949b5 (diff)
Optimized C++ and CUDA kernels for the matrix_set_diag op. The new code is faster and more readable, and it avoids an issue with using the Eigen generator mechanism with GPUs on Windows.
PiperOrigin-RevId: 171924800
Diffstat (limited to 'tensorflow/core/kernels/matrix_set_diag_op.cc')
-rw-r--r--  tensorflow/core/kernels/matrix_set_diag_op.cc  71
1 file changed, 31 insertions(+), 40 deletions(-)
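Before the diff itself, a quick illustration of what the op computes may help: matrix_set_diag copies each matrix in a batch and overwrites its main diagonal with the corresponding vector. The snippet below is a standalone sketch of those semantics in plain C++, not TensorFlow code; the name set_diag_batch and the flat row-major layout are illustrative only.

    // Standalone sketch of the matrix_set_diag semantics (not TensorFlow code):
    // copy each input matrix and overwrite its main diagonal with the given vector.
    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // input:  batch x rows x cols, row-major
    // diag:   batch x min(rows, cols)
    // output: batch x rows x cols
    void set_diag_batch(const std::vector<float>& input, const std::vector<float>& diag,
                        std::vector<float>* output, int64_t batch, int64_t rows, int64_t cols) {
      const int64_t min_dim = std::min(rows, cols);
      *output = input;  // start from a copy of the input
      for (int64_t b = 0; b < batch; ++b) {
        for (int64_t d = 0; d < min_dim; ++d) {
          (*output)[(b * rows + d) * cols + d] = diag[b * min_dim + d];
        }
      }
    }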
diff --git a/tensorflow/core/kernels/matrix_set_diag_op.cc b/tensorflow/core/kernels/matrix_set_diag_op.cc
index 9573c4f8d1..9dd665392b 100644
--- a/tensorflow/core/kernels/matrix_set_diag_op.cc
+++ b/tensorflow/core/kernels/matrix_set_diag_op.cc
@@ -23,8 +23,6 @@ limitations under the License.
#include "tensorflow/core/kernels/matrix_set_diag_op.h"
-#include <memory>
-#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -32,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -73,22 +72,21 @@ class MatrixSetDiagOp : public OpKernel {
input_shape.DebugString(), " and diagonal shape: ",
diag_shape.DebugString()));
+ if (input.NumElements() == 0) {
+ // This is a no-op.
+ context->set_output(0, input);
+ return;
+ }
+
auto input_reshaped = input.flat_inner_dims<T, 3>();
auto diag_reshaped = diag.flat_inner_dims<T, 2>();
-
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{0}, 0, input_shape, &output));
auto output_reshaped = output->flat_inner_dims<T, 3>();
- Tensor scratch_tensor;
- OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({}), &scratch_tensor));
- auto scratch = scratch_tensor.scalar<T>();
-
- functor::MatrixSetDiag<Device, T>::Compute(context->eigen_device<Device>(),
- input_reshaped, diag_reshaped,
- scratch, output_reshaped);
+ functor::MatrixSetDiag<Device, T>::Compute(
+ context, context->eigen_device<Device>(), input_reshaped, diag_reshaped,
+ output_reshaped);
}
private:
@@ -116,32 +114,25 @@ namespace functor {
// Implementation of the functor specialization for CPU.
template <typename T>
struct MatrixSetDiag<CPUDevice, T> {
- static void Compute(const CPUDevice& d,
+ static void Compute(OpKernelContext* context, const CPUDevice& device,
typename TTypes<T, 3>::ConstTensor input,
typename TTypes<T, 2>::ConstTensor diag,
- typename TTypes<T>::Scalar scratch,
typename TTypes<T, 3>::Tensor output) {
- output.device(d) = input;
- for (int64 r = 0; r < output.dimension(0); ++r) {
- for (int64 d = 0; d < diag.dimension(1); ++d) {
- output(r, d, d) = diag(r, d);
- }
+ if (input.data() != output.data()) {
+ output.device(device) = input;
}
- }
-};
-
-template <>
-struct MatrixSetDiag<CPUDevice, bool> {
- static void Compute(const CPUDevice& d, TTypes<bool, 3>::ConstTensor input,
- TTypes<bool, 2>::ConstTensor diag,
- TTypes<bool>::Scalar scratch,
- TTypes<bool, 3>::Tensor output) {
- output.device(d) = input;
- for (int64 r = 0; r < output.dimension(0); ++r) {
- for (int64 d = 0; d < diag.dimension(1); ++d) {
- output(r, d, d) = diag(r, d);
+ auto compute_shard = [&output, &diag](int64 begin, int64 end) {
+ for (int64 batch = begin; batch < end; ++batch) {
+ for (int64 col = 0; col < diag.dimension(1); ++col) {
+ output(batch, col, col) = diag(batch, col);
+ }
}
- }
+ };
+ auto thread_pool =
+ context->device()->tensorflow_cpu_worker_threads()->workers;
+ int64 cost_per_batch = 10 * output.dimension(1); // Heuristic.
+ thread_pool->ParallelFor(output.dimension(0), cost_per_batch,
+ std::move(compute_shard));
}
};
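The new CPU specialization above replaces the Eigen generator with an explicit loop sharded over the batch dimension via the device's thread pool, and it skips the initial copy entirely when forward_input_or_allocate_output handed back the input buffer. Below is a standalone sketch of that sharding pattern, assuming nothing from TensorFlow: a toy parallel_for splits [0, total) into contiguous ranges and runs each range on its own std::thread worker. TensorFlow's ThreadPool::ParallelFor additionally uses the cost_per_batch estimate to pick shard sizes; here the split is simply even across a fixed thread count.

    // Standalone sketch of batch-sharded work (not TensorFlow's ThreadPool):
    // split [0, total) into contiguous ranges and run each range on a thread.
    #include <algorithm>
    #include <cstdint>
    #include <functional>
    #include <thread>
    #include <vector>

    void parallel_for(int64_t total, int num_threads,
                      const std::function<void(int64_t, int64_t)>& shard) {
      std::vector<std::thread> workers;
      const int64_t block = std::max<int64_t>(1, (total + num_threads - 1) / num_threads);
      for (int64_t begin = 0; begin < total; begin += block) {
        const int64_t end = std::min(begin + block, total);
        workers.emplace_back(shard, begin, end);  // worker handles batches [begin, end)
      }
      for (auto& t : workers) t.join();
    }

In the kernel itself, compute_shard plays the role of the shard callable: for each batch in [begin, end) it writes output(batch, col, col) = diag(batch, col), so shards never touch overlapping elements and no synchronization is needed.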
@@ -151,13 +142,13 @@ struct MatrixSetDiag<CPUDevice, bool> {
// Forward declarations of the functor specializations for GPU.
namespace functor {
-#define DECLARE_GPU_SPEC(T) \
- template <> \
- void MatrixSetDiag<GPUDevice, T>::Compute( \
- const GPUDevice& d, typename TTypes<T, 3>::ConstTensor input, \
- typename TTypes<T, 2>::ConstTensor diag, \
- typename TTypes<T>::Scalar scratch, \
- typename TTypes<T, 3>::Tensor output); \
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void MatrixSetDiag<GPUDevice, T>::Compute( \
+ OpKernelContext* context, const GPUDevice& d, \
+ typename TTypes<T, 3>::ConstTensor input, \
+ typename TTypes<T, 2>::ConstTensor diag, \
+ typename TTypes<T, 3>::Tensor output); \
extern template struct MatrixSetDiag<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);