diff options
author | 2017-10-11 23:59:11 -0700 | |
---|---|---|
committer | 2017-10-12 00:02:52 -0700 | |
commit | cec93f10dcf5d2e647a41bd6bf95357cf9d9169d (patch) | |
tree | 5fc58a2e8cbd9becdda1c6cbfad1045c35dd9b74 /tensorflow/core/kernels/matrix_set_diag_op.cc | |
parent | d455bd6a851450657c702808d096f39583b949b5 (diff) |
Optimized C++ and CUDA kernels for matrix_set_diag op. The new code is faster and more readable and avoids an issue with using the Eigen generator mechanism with GPUs on Windows.
PiperOrigin-RevId: 171924800
Diffstat (limited to 'tensorflow/core/kernels/matrix_set_diag_op.cc')
-rw-r--r-- | tensorflow/core/kernels/matrix_set_diag_op.cc | 71 |
1 files changed, 31 insertions, 40 deletions
```diff
diff --git a/tensorflow/core/kernels/matrix_set_diag_op.cc b/tensorflow/core/kernels/matrix_set_diag_op.cc
index 9573c4f8d1..9dd665392b 100644
--- a/tensorflow/core/kernels/matrix_set_diag_op.cc
+++ b/tensorflow/core/kernels/matrix_set_diag_op.cc
@@ -23,8 +23,6 @@ limitations under the License.
 
 #include "tensorflow/core/kernels/matrix_set_diag_op.h"
 
-#include <memory>
-#include <vector>
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -32,6 +30,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
 
@@ -73,22 +72,21 @@ class MatrixSetDiagOp : public OpKernel {
                     input_shape.DebugString(), " and diagonal shape: ",
                     diag_shape.DebugString()));
 
+    if (input.NumElements() == 0) {
+      // This is a no-op.
+      context->set_output(0, input);
+      return;
+    }
+
     auto input_reshaped = input.flat_inner_dims<T, 3>();
     auto diag_reshaped = diag.flat_inner_dims<T, 2>();
-
     Tensor* output = nullptr;
     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                 {0}, 0, input_shape, &output));
     auto output_reshaped = output->flat_inner_dims<T, 3>();
-    Tensor scratch_tensor;
-    OP_REQUIRES_OK(context,
-                   context->allocate_temp(DataTypeToEnum<T>::value,
-                                          TensorShape({}), &scratch_tensor));
-    auto scratch = scratch_tensor.scalar<T>();
-
-    functor::MatrixSetDiag<Device, T>::Compute(context->eigen_device<Device>(),
-                                               input_reshaped, diag_reshaped,
-                                               scratch, output_reshaped);
+    functor::MatrixSetDiag<Device, T>::Compute(
+        context, context->eigen_device<Device>(), input_reshaped, diag_reshaped,
+        output_reshaped);
   }
 
  private:
@@ -116,32 +114,25 @@ namespace functor {
 // Implementation of the functor specialization for CPU.
 template <typename T>
 struct MatrixSetDiag<CPUDevice, T> {
-  static void Compute(const CPUDevice& d,
+  static void Compute(OpKernelContext* context, const CPUDevice& device,
                       typename TTypes<T, 3>::ConstTensor input,
                       typename TTypes<T, 2>::ConstTensor diag,
-                      typename TTypes<T>::Scalar scratch,
                       typename TTypes<T, 3>::Tensor output) {
-    output.device(d) = input;
-    for (int64 r = 0; r < output.dimension(0); ++r) {
-      for (int64 d = 0; d < diag.dimension(1); ++d) {
-        output(r, d, d) = diag(r, d);
-      }
+    if (input.data() != output.data()) {
+      output.device(device) = input;
     }
-  }
-};
-
-template <>
-struct MatrixSetDiag<CPUDevice, bool> {
-  static void Compute(const CPUDevice& d, TTypes<bool, 3>::ConstTensor input,
-                      TTypes<bool, 2>::ConstTensor diag,
-                      TTypes<bool>::Scalar scratch,
-                      TTypes<bool, 3>::Tensor output) {
-    output.device(d) = input;
-    for (int64 r = 0; r < output.dimension(0); ++r) {
-      for (int64 d = 0; d < diag.dimension(1); ++d) {
-        output(r, d, d) = diag(r, d);
+    auto compute_shard = [&output, &diag](int64 begin, int64 end) {
+      for (int64 batch = begin; batch < end; ++batch) {
+        for (int64 col = 0; col < diag.dimension(1); ++col) {
+          output(batch, col, col) = diag(batch, col);
+        }
       }
-    }
+    };
+    auto thread_pool =
+        context->device()->tensorflow_cpu_worker_threads()->workers;
+    int64 cost_per_batch = 10 * output.dimension(1); // Heuristic.
+    thread_pool->ParallelFor(output.dimension(0), cost_per_batch,
+                             std::move(compute_shard));
   }
 };
 
@@ -151,13 +142,13 @@ struct MatrixSetDiag<CPUDevice, bool> {
 
 // Forward declarations of the functor specializations for GPU.
 namespace functor {
-#define DECLARE_GPU_SPEC(T)                                         \
-  template <>                                                       \
-  void MatrixSetDiag<GPUDevice, T>::Compute(                        \
-      const GPUDevice& d, typename TTypes<T, 3>::ConstTensor input, \
-      typename TTypes<T, 2>::ConstTensor diag,                      \
-      typename TTypes<T>::Scalar scratch,                           \
-      typename TTypes<T, 3>::Tensor output);                        \
+#define DECLARE_GPU_SPEC(T)                         \
+  template <>                                       \
+  void MatrixSetDiag<GPUDevice, T>::Compute(        \
+      OpKernelContext* context, const GPUDevice& d, \
+      typename TTypes<T, 3>::ConstTensor input,     \
+      typename TTypes<T, 2>::ConstTensor diag,      \
+      typename TTypes<T, 3>::Tensor output);        \
   extern template struct MatrixSetDiag<GPUDevice, T>;
 
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
```