Diffstat (limited to 'tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc')
-rw-r--r-- | tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc | 143
1 file changed, 81 insertions, 62 deletions
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
index 2bcc7f407d..30026f222a 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
@@ -29,7 +29,7 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
-template <typename Device, typename T>
+template <typename Device, typename T, typename Tindices>
 class SparseTensorDenseMatMulOp : public OpKernel {
  public:
   explicit SparseTensorDenseMatMulOp(OpKernelConstruction* ctx)
@@ -139,15 +139,14 @@ class SparseTensorDenseMatMulOp : public OpKernel {
                                            TensorShape({0}), &scratch));
     }
 
-#define MAYBE_ADJOINT(ADJ_A, ADJ_B)                                          \
-  if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) {                          \
-    Status functor_status = functor::SparseTensorDenseMatMulFunctor<         \
-        Device, T, ADJ_A, ADJ_B>::Compute(ctx->eigen_device<Device>(),       \
-                                          out->matrix<T>(),                  \
-                                          a_indices->matrix<int64>(),        \
-                                          a_values->vec<T>(), b->matrix<T>(), \
-                                          scratch.vec<T>());                 \
-    OP_REQUIRES_OK(ctx, functor_status);                                     \
+#define MAYBE_ADJOINT(ADJ_A, ADJ_B)                                          \
+  if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) {                          \
+    Status functor_status = functor::SparseTensorDenseMatMulFunctor<         \
+        Device, T, Tindices, ADJ_A,                                          \
+        ADJ_B>::Compute(ctx->eigen_device<Device>(), out->matrix<T>(),       \
+                        a_indices->matrix<Tindices>(), a_values->vec<T>(),   \
+                        b->matrix<T>(), scratch.vec<T>());                   \
+    OP_REQUIRES_OK(ctx, functor_status);                                     \
   }
 
     MAYBE_ADJOINT(false, false);
@@ -163,53 +162,73 @@ class SparseTensorDenseMatMulOp : public OpKernel {
   bool adjoint_b_;
 };
 
-#define REGISTER_CPU(T)                                   \
-  REGISTER_KERNEL_BUILDER(Name("SparseTensorDenseMatMul") \
-                              .Device(DEVICE_CPU)         \
-                              .TypeConstraint<T>("T")     \
-                              .HostMemory("a_shape"),     \
-                          SparseTensorDenseMatMulOp<CPUDevice, T>);
-
-REGISTER_CPU(float);
-REGISTER_CPU(double);
-REGISTER_CPU(int32);
-REGISTER_CPU(complex64);
-REGISTER_CPU(complex128);
+#define REGISTER_CPU(TypeT, TypeIndex)                        \
+  REGISTER_KERNEL_BUILDER(                                    \
+      Name("SparseTensorDenseMatMul")                         \
+          .Device(DEVICE_CPU)                                 \
+          .TypeConstraint<TypeT>("T")                         \
+          .TypeConstraint<TypeIndex>("Tindices")              \
+          .HostMemory("a_shape"),                             \
+      SparseTensorDenseMatMulOp<CPUDevice, TypeT, TypeIndex>);
+
+#define REGISTER_KERNELS_CPU(T) \
+  REGISTER_CPU(T, int64);       \
+  REGISTER_CPU(T, int32)
+
+REGISTER_KERNELS_CPU(float);
+REGISTER_KERNELS_CPU(double);
+REGISTER_KERNELS_CPU(int32);
+REGISTER_KERNELS_CPU(complex64);
+REGISTER_KERNELS_CPU(complex128);
 
 #if GOOGLE_CUDA
 
 namespace functor {
-#define DECLARE_GPU_SPEC(T, ADJ_A, ADJ_B)                                      \
-  template <>                                                                  \
-  Status SparseTensorDenseMatMulFunctor<GPUDevice, T, ADJ_A, ADJ_B>::Compute(  \
-      const GPUDevice& d, typename TTypes<T>::Matrix out,                      \
-      TTypes<int64>::ConstMatrix a_indices,                                    \
-      typename TTypes<T>::ConstVec a_values,                                   \
-      typename TTypes<T>::ConstMatrix b, typename TTypes<T>::Vec scratch);     \
-  extern template struct SparseTensorDenseMatMulFunctor<GPUDevice, T, ADJ_A,   \
-                                                        ADJ_B>;
-
-#define DECLARE_ADJOINT_GPU_SPEC(T) \
-  DECLARE_GPU_SPEC(T, false, false) \
-  DECLARE_GPU_SPEC(T, false, true)  \
-  DECLARE_GPU_SPEC(T, true, false)  \
-  DECLARE_GPU_SPEC(T, true, true)
+#define DECLARE_GPU_SPEC(T, Tindices, ADJ_A, ADJ_B)                        \
+  template <>                                                              \
+  Status SparseTensorDenseMatMulFunctor<                                   \
+      GPUDevice, T, Tindices, ADJ_A,                                       \
+      ADJ_B>::Compute(const GPUDevice& d, typename TTypes<T>::Matrix out,  \
+                      typename TTypes<Tindices>::ConstMatrix a_indices,    \
+                      typename TTypes<T>::ConstVec a_values,               \
+                      typename TTypes<T>::ConstMatrix b,                   \
+                      typename TTypes<T>::Vec scratch);                    \
+  extern template struct SparseTensorDenseMatMulFunctor<                   \
+      GPUDevice, T, Tindices, ADJ_A, ADJ_B>;
+
+#define REGISTER_GPU_SPEC(T, ADJ_A, ADJ_B)  \
+  DECLARE_GPU_SPEC(T, int32, ADJ_A, ADJ_B); \
+  DECLARE_GPU_SPEC(T, int64, ADJ_A, ADJ_B)
+
+#define DECLARE_ADJOINT_GPU_SPEC(T)  \
+  REGISTER_GPU_SPEC(T, false, false) \
+  REGISTER_GPU_SPEC(T, false, true)  \
+  REGISTER_GPU_SPEC(T, true, false)  \
+  REGISTER_GPU_SPEC(T, true, true)
 
 DECLARE_ADJOINT_GPU_SPEC(float);
 #undef DECLARE_ADJOINT_GPU_SPEC
 #undef DECLARE_GPU_SPEC
+#undef REGISTER_GPU_SPEC
 
 }  // namespace functor
 
-#define REGISTER_GPU(T)                                    \
-  REGISTER_KERNEL_BUILDER(Name("SparseTensorDenseMatMul")  \
-                              .Device(DEVICE_GPU)          \
-                              .TypeConstraint<T>("T")      \
-                              .HostMemory("a_shape"),      \
-                          SparseTensorDenseMatMulOp<GPUDevice, T>);
+#define REGISTER_GPU(TypeT, TypeIndex)                        \
+  REGISTER_KERNEL_BUILDER(                                    \
+      Name("SparseTensorDenseMatMul")                         \
+          .Device(DEVICE_GPU)                                 \
+          .TypeConstraint<TypeT>("T")                         \
+          .TypeConstraint<TypeIndex>("Tindices")              \
+          .HostMemory("a_shape"),                             \
+      SparseTensorDenseMatMulOp<GPUDevice, TypeT, TypeIndex>);
+
+#define REGISTER_KERNELS_GPU(T) \
+  REGISTER_GPU(T, int64);       \
+  REGISTER_GPU(T, int32)
 
-REGISTER_GPU(float);
+REGISTER_KERNELS_GPU(float);
 #undef REGISTER_GPU
+#undef REGISTER_KERNELS_GPU
 
 #endif  // GOOGLE_CUDA
 
 namespace functor {
@@ -228,13 +247,13 @@ Status MOutOfBoundsError(int64 m, std::size_t i, int lhs_index_a,
 }
 }  // namespace
 
-template <typename T, bool ADJ_A, bool ADJ_B>
-struct SparseTensorDenseMatMulFunctor<CPUDevice, T, ADJ_A, ADJ_B> {
+template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
+struct SparseTensorDenseMatMulFunctor<CPUDevice, T, Tindices, ADJ_A, ADJ_B> {
   // Vectorize certain operations above this size.
   static const std::size_t kNumVectorize = 32;
 
   static Status Compute(const CPUDevice& d, typename TTypes<T>::Matrix out,
-                        TTypes<int64>::ConstMatrix a_indices,
+                        typename TTypes<Tindices>::ConstMatrix a_indices,
                         typename TTypes<T>::ConstVec a_values,
                         typename TTypes<T>::ConstMatrix b,
                         typename TTypes<T>::Vec scratch) {
@@ -255,8 +274,8 @@ struct SparseTensorDenseMatMulFunctor<CPUDevice, T, ADJ_A, ADJ_B> {
     auto maybe_adjoint_b = MaybeAdjoint<decltype(b), ADJ_B>(b);
 
     for (std::size_t i = 0; i < nnz; ++i) {
-      const int64 m = internal::SubtleMustCopy(a_indices(i, lhs_index_a));
-      const int64 k = internal::SubtleMustCopy(a_indices(i, rhs_index_a));
+      const Tindices m = internal::SubtleMustCopy(a_indices(i, lhs_index_a));
+      const Tindices k = internal::SubtleMustCopy(a_indices(i, rhs_index_a));
       if (!FastBoundsCheck(k, lhs_right)) {
         return KOutOfBoundsError(k, i, rhs_index_a, lhs_right);
       }
@@ -273,19 +292,19 @@ struct SparseTensorDenseMatMulFunctor<CPUDevice, T, ADJ_A, ADJ_B> {
       // Vectorization via Eigen.
       const int b_chip_index = ADJ_B ? 1 : 0;
 
-#define LOOP_NNZ(b_passed)                                               \
-  for (std::size_t i = 0; i < nnz; ++i) {                                \
-    const int64 m = internal::SubtleMustCopy(a_indices(i, lhs_index_a)); \
-    const int64 k = internal::SubtleMustCopy(a_indices(i, rhs_index_a)); \
-    const T a_value = (ADJ_A) ? MaybeConj(a_values(i)) : a_values(i);    \
-    if (!FastBoundsCheck(k, lhs_right)) {                                \
-      return KOutOfBoundsError(k, i, rhs_index_a, lhs_right);            \
-    }                                                                    \
-    if (!FastBoundsCheck(m, out.dimension(0))) {                         \
-      return MOutOfBoundsError(m, i, lhs_index_a, out.dimension(0));     \
-    }                                                                    \
-    out.template chip<0>(m) +=                                           \
-        b_passed.template chip<b_chip_index>(k) * a_value;               \
+#define LOOP_NNZ(b_passed)                                                  \
+  for (std::size_t i = 0; i < nnz; ++i) {                                   \
+    const Tindices m = internal::SubtleMustCopy(a_indices(i, lhs_index_a)); \
+    const Tindices k = internal::SubtleMustCopy(a_indices(i, rhs_index_a)); \
+    const T a_value = (ADJ_A) ? MaybeConj(a_values(i)) : a_values(i);       \
+    if (!FastBoundsCheck(k, lhs_right)) {                                   \
+      return KOutOfBoundsError(k, i, rhs_index_a, lhs_right);               \
+    }                                                                       \
+    if (!FastBoundsCheck(m, out.dimension(0))) {                            \
+      return MOutOfBoundsError(m, i, lhs_index_a, out.dimension(0));        \
+    }                                                                       \
+    out.template chip<0>(m) +=                                              \
+        b_passed.template chip<b_chip_index>(k) * a_value;                  \
   }
 
       if (ADJ_B) {
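
For reference, a minimal sketch (not part of the commit) of what the new CPU registration boils down to for one value type. It is the mechanical expansion of REGISTER_KERNELS_CPU(float) using the macros from the diff above, and shows why the op now accepts either 64-bit or 32-bit sparse indices, dispatched on the "Tindices" attr:

// Hypothetical expansion of REGISTER_KERNELS_CPU(float): one kernel
// registration per index type. The kernel builder selects the instantiation
// whose "Tindices" constraint matches the attr on the op node.
REGISTER_KERNEL_BUILDER(
    Name("SparseTensorDenseMatMul")
        .Device(DEVICE_CPU)
        .TypeConstraint<float>("T")
        .TypeConstraint<int64>("Tindices")
        .HostMemory("a_shape"),
    SparseTensorDenseMatMulOp<CPUDevice, float, int64>);
REGISTER_KERNEL_BUILDER(
    Name("SparseTensorDenseMatMul")
        .Device(DEVICE_CPU)
        .TypeConstraint<float>("T")
        .TypeConstraint<int32>("Tindices")
        .HostMemory("a_shape"),
    SparseTensorDenseMatMulOp<CPUDevice, float, int32>);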