aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-05-04 15:35:27 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-04 16:50:16 -0700
commitffc5e6fbf68b5229b68cfe96bbcaf58619277c06 (patch)
tree6e42388824efe17c6f1bc4678adf99446e7cdf71 /tensorflow
parent7ff483746404a3ca7ed66ae4271b21bcb07082ee (diff)
[TF] optimization: SparseTensorDenseMatMul GPU kernel rewritten in pure cuda.
Also added an additional GPU int32::max check that was missing. Performance seems to be between 1x-10x faster on average. The likely culprit on CPU slowdown was probably the unnecessary temp allocation for scratch space. Performance on a k40, compiled -c opt --config cuda --copt=-mavx: **BEFORE** Matrix sizes: A sparse [m, k] with % nonzero values between 1% and 80% B dense [k, n] % nnz n gpu m k dt(dense) dt(sparse) dt(sparse)/dt(dense) 0.01 50 True 100 100 0.000319954 0.000275495 0.861045 0.01 50 True 100 1000 0.000469565 0.000290895 0.619497 0.01 50 True 1000 100 0.000572815 0.000271131 0.473331 0.01 50 True 1000 1000 0.00133119 0.00042006 0.315554 0.01 50 False 100 100 0.00034191 0.000289171 0.845751 0.01 50 False 100 1000 0.0004796 0.00028483 0.593891 0.01 50 False 1000 100 0.000632371 0.000300461 0.475134 0.01 50 False 1000 1000 0.00134726 0.000576285 0.427746 0.01 100 True 100 100 0.000353755 0.00027729 0.783849 0.01 100 True 100 1000 0.000536649 0.00028337 0.528036 0.01 100 True 1000 100 0.000661941 0.00027933 0.421987 0.01 100 True 1000 1000 0.0014109 0.0006698 0.474732 0.01 100 False 100 100 0.00039546 0.00030159 0.762631 0.01 100 False 100 1000 0.00054909 0.00027276 0.49675 0.01 100 False 1000 100 0.000631344 0.00028231 0.447157 0.01 100 False 1000 1000 0.00141789 0.000657049 0.463398 0.2 50 True 100 100 0.00033689 0.000280155 0.831591 0.2 50 True 100 1000 0.000563495 0.00064159 1.13859 0.2 50 True 1000 100 0.00058635 0.00067611 1.15308 0.2 50 True 1000 1000 0.00153552 0.00486242 3.16662 0.2 50 False 100 100 0.000333545 0.000267555 0.802154 0.2 50 False 100 1000 0.000544 0.00066272 1.21824 0.2 50 False 1000 100 0.00058253 0.000670955 1.15179 0.2 50 False 1000 1000 0.00153017 0.00480928 3.14298 0.2 100 True 100 100 0.00036919 0.000288659 0.781872 0.2 100 True 100 1000 0.00067063 0.00110059 1.64113 0.2 100 True 1000 100 0.00066443 0.00108547 1.63369 0.2 100 True 1000 1000 0.00180991 0.00961579 5.31286 0.2 100 False 100 100 0.00040061 0.000325365 0.812174 0.2 100 False 100 1000 0.00066774 0.00111843 1.67494 0.2 100 False 1000 100 0.000696205 0.00108078 1.55239 0.2 100 False 1000 1000 0.00179788 0.00960569 5.34278 0.5 50 True 100 100 0.00034819 0.00033425 0.959963 0.5 50 True 100 1000 0.00075176 0.00134084 1.78359 0.5 50 True 1000 100 0.000642445 0.00133641 2.08019 0.5 50 True 1000 1000 0.00233791 0.0124282 5.31597 0.5 50 False 100 100 0.000345069 0.000334586 0.96962 0.5 50 False 100 1000 0.00071701 0.00135879 1.89508 0.5 50 False 1000 100 0.000632119 0.00134036 2.12043 0.5 50 False 1000 1000 0.00240216 0.0126202 5.25368 0.5 100 True 100 100 0.000393934 0.00040344 1.02413 0.5 100 True 100 1000 0.000957675 0.002709 2.82873 0.5 100 True 1000 100 0.000756125 0.00242428 3.20619 0.5 100 True 1000 1000 0.00298202 0.0241416 8.09572 0.5 100 False 100 100 0.000395606 0.000433675 1.09623 0.5 100 False 100 1000 0.000963565 0.00248293 2.57682 0.5 100 False 1000 100 0.00079523 0.0024281 3.05333 0.5 100 False 1000 1000 0.00299668 0.0242615 8.09614 0.8 50 True 100 100 0.00036806 0.00040923 1.11186 0.8 50 True 100 1000 0.00091419 0.00207383 2.26848 0.8 50 True 1000 100 0.000684329 0.00196612 2.87307 0.8 50 True 1000 1000 0.00302433 0.0199798 6.60637 0.8 50 False 100 100 0.000368149 0.000615025 1.67058 0.8 50 False 100 1000 0.0008786 0.00205821 2.3426 0.8 50 False 1000 100 0.00067889 0.00195498 2.87967 0.8 50 False 1000 1000 0.00290009 0.0191242 6.59434 0.8 100 True 100 100 0.000452549 0.00063767 1.40906 0.8 100 True 100 1000 0.00126929 0.00391422 3.08378 0.8 100 True 1000 100 0.000919235 0.00386167 4.20096 0.8 100 True 1000 1000 0.00423295 0.0431824 10.2015 0.8 100 False 100 100 0.000428261 0.000626891 1.46381 0.8 100 False 100 1000 0.00120801 0.00395877 3.27711 0.8 100 False 1000 100 0.00080466 0.00385143 4.78641 0.8 100 False 1000 1000 0.00370808 0.0403527 10.8824 **AFTER** Matrix sizes: A sparse [m, k] with % nonzero values between 1% and 80% B dense [k, n] % nnz n gpu m k dt(dense) dt(sparse) dt(sparse)/dt(dense) 0.01 50 True 100 100 0.000312485 0.00020528 0.656927 0.01 50 True 100 1000 0.0004655 0.00020095 0.431686 0.01 50 True 1000 100 0.000567449 0.000203935 0.359389 0.01 50 True 1000 1000 0.00132323 0.00027171 0.205339 0.01 50 False 100 100 0.000319945 0.000197511 0.617328 0.01 50 False 100 1000 0.000466419 0.000210185 0.450635 0.01 50 False 1000 100 0.0005581 0.000199865 0.358117 0.01 50 False 1000 1000 0.00129479 0.000451496 0.348702 0.01 100 True 100 100 0.000364131 0.000196835 0.540561 0.01 100 True 100 1000 0.00053398 0.000206494 0.386708 0.01 100 True 1000 100 0.00062722 0.000203185 0.323946 0.01 100 True 1000 1000 0.00138674 0.000335904 0.242227 0.01 100 False 100 100 0.000361339 0.000195 0.53966 0.01 100 False 100 1000 0.000531831 0.000207155 0.389513 0.01 100 False 1000 100 0.00062245 0.000197015 0.316515 0.01 100 False 1000 1000 0.0014007 0.000328825 0.234757 0.2 50 True 100 100 0.00033185 0.000262895 0.792209 0.2 50 True 100 1000 0.00054391 0.000586189 1.07773 0.2 50 True 1000 100 0.000581805 0.000531535 0.913597 0.2 50 True 1000 1000 0.00153913 0.00142783 0.927687 0.2 50 False 100 100 0.00033572 0.000266831 0.794803 0.2 50 False 100 1000 0.000534315 0.000585151 1.09514 0.2 50 False 1000 100 0.000580961 0.00033344 0.573947 0.2 50 False 1000 1000 0.0015055 0.00143968 0.956284 0.2 100 True 100 100 0.000371666 0.00026337 0.708621 0.2 100 True 100 1000 0.000667235 0.00056811 0.851439 0.2 100 True 1000 100 0.000671356 0.000400575 0.596666 0.2 100 True 1000 1000 0.00178568 0.00250393 1.40222 0.2 100 False 100 100 0.000370425 0.000254935 0.688223 0.2 100 False 100 1000 0.000661175 0.000601134 0.909191 0.2 100 False 1000 100 0.0006944 0.00039817 0.573401 0.2 100 False 1000 1000 0.00176969 0.0024947 1.40968 0.5 50 True 100 100 0.000346885 0.000263295 0.759028 0.5 50 True 100 1000 0.00073113 0.00107669 1.47263 0.5 50 True 1000 100 0.000672774 0.000493085 0.732914 0.5 50 True 1000 1000 0.00260436 0.003335 1.28054 0.5 50 False 100 100 0.00036242 0.000273196 0.753809 0.5 50 False 100 1000 0.000753295 0.00107086 1.42157 0.5 50 False 1000 100 0.00064886 0.000501654 0.773132 0.5 50 False 1000 1000 0.00241105 0.0033146 1.37475 0.5 100 True 100 100 0.000401269 0.00027831 0.693573 0.5 100 True 100 1000 0.00094245 0.00111468 1.18275 0.5 100 True 1000 100 0.00075719 0.00074962 0.990003 0.5 100 True 1000 1000 0.00297528 0.00601445 2.02147 0.5 100 False 100 100 0.000408576 0.00026246 0.642377 0.5 100 False 100 1000 0.00094272 0.00112762 1.19613 0.5 100 False 1000 100 0.000762925 0.00074343 0.974446 0.5 100 False 1000 1000 0.00314936 0.00604122 1.91824 0.8 50 True 100 100 0.00036589 0.000331376 0.905669 0.8 50 True 100 1000 0.00086403 0.00171248 1.98197 0.8 50 True 1000 100 0.00067048 0.000715261 1.06679 0.8 50 True 1000 1000 0.00284684 0.00527865 1.85422 0.8 50 False 100 100 0.000357161 0.000540144 1.51233 0.8 50 False 100 1000 0.000884765 0.00170428 1.92625 0.8 50 False 1000 100 0.000666975 0.000737065 1.10509 0.8 50 False 1000 1000 0.0028149 0.00530442 1.88441 0.8 100 True 100 100 0.00041237 0.00034323 0.832335 0.8 100 True 100 1000 0.00122102 0.00179725 1.47192 0.8 100 True 1000 100 0.000807976 0.00111246 1.37684 0.8 100 True 1000 1000 0.00379081 0.00968211 2.5541 0.8 100 False 100 100 0.000426315 0.000339085 0.795386 0.8 100 False 100 1000 0.00144096 0.00179819 1.2479 0.8 100 False 1000 100 0.000951196 0.0011155 1.17274 0.8 100 False 1000 1000 0.0039524 0.00980128 2.47983 Change: 155142876
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc65
-rw-r--r--tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h9
-rw-r--r--tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc146
-rw-r--r--tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py87
4 files changed, 146 insertions, 161 deletions
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
index 30026f222a..30c57ef287 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
@@ -65,7 +65,8 @@ class SparseTensorDenseMatMulOp : public OpKernel {
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a_indices->shape()),
errors::InvalidArgument("Tensor 'a_indices' is not a matrix"));
- OP_REQUIRES(ctx, a_indices->shape().dim_size(0) == a_values->NumElements(),
+ const int64 nnz = a_indices->shape().dim_size(0);
+ OP_REQUIRES(ctx, nnz == a_values->NumElements(),
errors::InvalidArgument("Number of rows of a_indices does not "
"match number of entries in a_values"));
@@ -89,8 +90,28 @@ class SparseTensorDenseMatMulOp : public OpKernel {
inner_left, " vs. ", inner_right,
". Did you forget a transpose? "
"Dimensions of A: [",
- a_shape_t(0), ", ", a_shape_t(1), "). Dimensions of B: ",
- b->shape().DebugString()));
+ a_shape_t(0), ", ", a_shape_t(1),
+ "). Dimensions of B: ", b->shape().DebugString()));
+
+ if (std::is_same<Device, GPUDevice>::value) {
+ // The GPU implementation is optimized to use 32 bit indexing, so
+ // give a friendly error to the programmer early on if they
+ // exceed.
+ const int int32max = std::numeric_limits<int>::max();
+ OP_REQUIRES(
+ ctx,
+ (FastBoundsCheck(inner_left, int32max) &&
+ FastBoundsCheck(inner_right, int32max) &&
+ FastBoundsCheck(outer_left, int32max) &&
+ FastBoundsCheck(outer_right, int32max) &&
+ FastBoundsCheck(b->NumElements(), int32max) &&
+ FastBoundsCheck(outer_left * outer_right, int32max) &&
+ FastBoundsCheck(a_values->NumElements(), int32max)),
+ errors::InvalidArgument("Cannot use GPU for > 2^31 entry inputs"));
+ OP_REQUIRES(ctx, FastBoundsCheck(nnz * outer_right, int32max),
+ errors::InvalidArgument(
+ "Cannot use GPU when output.shape[1] * nnz(a) > 2^31"));
+ }
TensorShape out_shape({outer_left, outer_right});
Tensor* out = nullptr;
@@ -111,41 +132,13 @@ class SparseTensorDenseMatMulOp : public OpKernel {
return;
}
- Tensor scratch;
-
- if (std::is_same<Device, GPUDevice>::value) {
- // The GPU implementation is optimized to use 32 bit indexing, so
- // give a friendly error to the programmer early on if they exceed.
- OP_REQUIRES(
- ctx,
- FastBoundsCheck(inner_left, std::numeric_limits<int>::max()) &&
- FastBoundsCheck(inner_right, std::numeric_limits<int>::max()) &&
- FastBoundsCheck(outer_left, std::numeric_limits<int>::max()) &&
- FastBoundsCheck(outer_right, std::numeric_limits<int>::max()) &&
- FastBoundsCheck(b->NumElements(),
- std::numeric_limits<int>::max()) &&
- FastBoundsCheck(out->NumElements(),
- std::numeric_limits<int>::max()) &&
- FastBoundsCheck(a_values->NumElements(),
- std::numeric_limits<int>::max()),
- errors::InvalidArgument("Cannot use GPU for > 2^31 entry inputs"));
- const int nnz = static_cast<const int>(a_values->NumElements());
- // Need nnz length vec scratch space on the GPU.
- OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({nnz}), &scratch));
- } else {
- // We don't need scratch space on the CPU.
- OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({0}), &scratch));
- }
-
#define MAYBE_ADJOINT(ADJ_A, ADJ_B) \
if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) { \
Status functor_status = functor::SparseTensorDenseMatMulFunctor< \
Device, T, Tindices, ADJ_A, \
ADJ_B>::Compute(ctx->eigen_device<Device>(), out->matrix<T>(), \
a_indices->matrix<Tindices>(), a_values->vec<T>(), \
- b->matrix<T>(), scratch.vec<T>()); \
+ b->matrix<T>()); \
OP_REQUIRES_OK(ctx, functor_status); \
}
@@ -189,10 +182,9 @@ namespace functor {
Status SparseTensorDenseMatMulFunctor< \
GPUDevice, T, Tindices, ADJ_A, \
ADJ_B>::Compute(const GPUDevice& d, typename TTypes<T>::Matrix out, \
- typename TTypes<Tindices>::ConstMatrix a_indices, \
+ TTypes<Tindices>::ConstMatrix a_indices, \
typename TTypes<T>::ConstVec a_values, \
- typename TTypes<T>::ConstMatrix b, \
- typename TTypes<T>::Vec scratch); \
+ typename TTypes<T>::ConstMatrix b); \
extern template struct SparseTensorDenseMatMulFunctor< \
GPUDevice, T, Tindices, ADJ_A, ADJ_B>;
@@ -255,8 +247,7 @@ struct SparseTensorDenseMatMulFunctor<CPUDevice, T, Tindices, ADJ_A, ADJ_B> {
static Status Compute(const CPUDevice& d, typename TTypes<T>::Matrix out,
typename TTypes<Tindices>::ConstMatrix a_indices,
typename TTypes<T>::ConstVec a_values,
- typename TTypes<T>::ConstMatrix b,
- typename TTypes<T>::Vec scratch) {
+ typename TTypes<T>::ConstMatrix b) {
const std::size_t nnz = a_values.size();
const std::size_t rhs_right = (ADJ_B ? b.dimension(0) : b.dimension(1));
const std::size_t lhs_right = (ADJ_B ? b.dimension(1) : b.dimension(0));
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
index e707743f78..da13190494 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
@@ -28,11 +28,10 @@ namespace functor {
template <typename Device, typename T, typename Tindices, bool ADJ_A,
bool ADJ_B>
struct SparseTensorDenseMatMulFunctor {
- static EIGEN_ALWAYS_INLINE Status
- Compute(const Device& d, typename TTypes<T>::Matrix out,
- typename TTypes<Tindices>::ConstMatrix a_indices,
- typename TTypes<T>::ConstVec a_values,
- typename TTypes<T>::ConstMatrix b, typename TTypes<T>::Vec scratch);
+ static EIGEN_ALWAYS_INLINE Status Compute(
+ const Device& d, typename TTypes<T>::Matrix out,
+ typename TTypes<Tindices>::ConstMatrix a_indices,
+ typename TTypes<T>::ConstVec a_values, typename TTypes<T>::ConstMatrix b);
};
template <typename MATRIX, bool ADJ>
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
index 7266e0cf81..e261e42e0d 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
@@ -20,71 +20,45 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h"
#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
-namespace generator {
-
template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
-class SparseTensorDenseMatMulGPUGenerator {
- public:
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseTensorDenseMatMulGPUGenerator(
- typename TTypes<T, 2>::Tensor32Bit out,
- typename TTypes<const Tindices, 2>::Tensor32Bit a_indices,
- typename TTypes<const T, 1>::Tensor32Bit a_values,
- typename TTypes<const T, 2>::Tensor32Bit b)
- : out_(out),
- lhs_index_a_(ADJ_A ? 1 : 0),
- rhs_index_a_(ADJ_A ? 0 : 1),
- a_indices_(a_indices),
- a_values_(a_values),
- lhs_right_size(ADJ_B ? b.dimension(1) : b.dimension(0)),
- maybe_adjoint_b_(
- functor::MaybeAdjoint<typename TTypes<const T, 2>::Tensor32Bit,
- ADJ_B>(b)) {}
-
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
- operator()(const Eigen::array<int, 2>& j_and_ix) const {
-#ifdef __CUDA_ARCH__
- const int j = j_and_ix[0];
- const int ix = j_and_ix[1];
- int m = a_indices_(ix, lhs_index_a_);
- int k = a_indices_(ix, rhs_index_a_);
- assert(k < lhs_right_size);
- assert(m < out_.dimension(0));
- // If asserts are disabled, the caller is violating the sparse
- // tensor index contract, and so we return invalid results.
- // Force returning NaNs to try to signal that something is amiss.
- T b_value;
- if (k >= lhs_right_size || m >= out_.dimension(0)) {
- m = 0;
- k = 0;
- b_value = std::numeric_limits<T>::quiet_NaN();
- } else {
- b_value = maybe_adjoint_b_(k, j);
+__global__ void SparseTensorDenseMatMulKernel(int nnz, int m, int b_rows,
+ int b_cols, int p,
+ const Tindices* a_indices,
+ const T* a_values, const T* b,
+ T* out) {
+ // out_{ij} = sum_k {a_ik b_kj}
+ // out = A * B', out_{ij} = sum_k {a_ik (b')_kj}; b'_{kj} = b_{jk}
+ const int n = (ADJ_B) ? b_cols : b_rows;
+ CUDA_1D_KERNEL_LOOP(index, nnz * p) {
+ const int a_ix = index / p;
+ const int j = index % p;
+ const int i = ldg(a_indices + 2 * a_ix + ((ADJ_A) ? 1 : 0));
+ const int k = ldg(a_indices + 2 * a_ix + ((ADJ_A) ? 0 : 1));
+ if (!FastBoundsCheck(i, m)) {
+ continue; // Nowhere to signal an error :(
+ }
+ // out[i, j]
+ T* out_location = out + i * p + j;
+ if (!FastBoundsCheck(k, n)) {
+ CudaAtomicAdd(out_location, std::numeric_limits<T>::quiet_NaN());
+ continue;
}
- atomicAdd(&out_(m, j), a_values_(ix) * b_value);
-#else
- assert(false && "This should only be run on the device");
-#endif
- // Return something
- return T(0);
- }
- private:
- mutable typename TTypes<T, 2>::Tensor32Bit out_;
- const int lhs_index_a_;
- const int rhs_index_a_;
- typename TTypes<const Tindices, 2>::Tensor32Bit a_indices_;
- typename TTypes<const T, 1>::Tensor32Bit a_values_;
- const int lhs_right_size;
- functor::MaybeAdjoint<typename TTypes<const T, 2>::Tensor32Bit, ADJ_B>
- maybe_adjoint_b_;
-};
+ // a_value == (ADJ_A) ? a[k, i] : a[i, k]
+ const T a_value = ldg(a_values + a_ix);
-} // namespace generator
+ // b_value == (ADJ_B) ? b[j, k] : b[k, j]
+ const T b_value = ldg(b + ((ADJ_B) ? j * b_cols + k : k * b_cols + j));
+ CudaAtomicAdd(out_location, a_value * b_value);
+ }
+}
namespace functor {
@@ -94,51 +68,23 @@ struct SparseTensorDenseMatMulFunctor<GPUDevice, T, Tindices, ADJ_A, ADJ_B> {
Compute(const GPUDevice& d, typename TTypes<T>::Matrix out,
typename TTypes<Tindices>::ConstMatrix a_indices,
typename TTypes<T>::ConstVec a_values,
- typename TTypes<T>::ConstMatrix b, typename TTypes<T>::Vec scratch) {
- generator::SparseTensorDenseMatMulGPUGenerator<T, Tindices, ADJ_A, ADJ_B>
- sparse_tensor_dense_matmul_generator(To32Bit(out), To32Bit(a_indices),
- To32Bit(a_values), To32Bit(b));
- To32Bit(out).device(d) = To32Bit(out).constant(T(0));
+ typename TTypes<T>::ConstMatrix b) {
+ out.device(d) = out.constant(T(0));
int nnz = a_values.size();
- int n = (ADJ_B) ? b.dimension(0) : b.dimension(1);
-
-#if !defined(EIGEN_HAS_INDEX_LIST)
- Eigen::Tensor<int, 2>::Dimensions matrix_1_by_nnz{{ 1, nnz }};
- Eigen::array<int, 2> n_by_1{{ n, 1 }};
- Eigen::array<int, 1> reduce_on_rows{{ 0 }};
-#else
- Eigen::IndexList<Eigen::type2index<1>, int> matrix_1_by_nnz;
- matrix_1_by_nnz.set(1, nnz);
- Eigen::IndexList<int, Eigen::type2index<1> > n_by_1;
- n_by_1.set(0, n);
- Eigen::IndexList<Eigen::type2index<0> > reduce_on_rows;
-#endif
-
- // How this works: the generator iterates over (j, ix) where j
- // iterates from 0 .. n - 1 and ix iterates from
- // 0 .. nnz - 1. A side effect of the generator is to accumulate
- // the products of values in A and B into the appropriate location
- // in the dense matrix out. In order to run the iteration,
- // we take a smaller variable and broadcast to a size (n, nnz).
- // This is the scratch variable. In order to enforce execution,
- // we have to perform assignment back into scratch (taking the sum).
- // We don't care what gets assigned to scratch - only the side effect
- // of the execution in the generator.
- //
- // Note it's not sufficient that scratch be a scalar, and to
- // broadcast it to a matrix. Eigen splits the computation not
- // based on the largest intermediate shape (the size of the
- // broadcast of scratch) but based on the output shape. So
- // scratch needs to be a vector at least.
- //
- // Note also that only float type is supported because the
- // atomicAdd operation is only supported for floats in hardware.
- To32Bit(scratch).device(d) =
- To32Bit(scratch)
- .reshape(matrix_1_by_nnz)
- .broadcast(n_by_1)
- .generate(sparse_tensor_dense_matmul_generator)
- .sum(reduce_on_rows);
+ // out = A * B, A is [m x n] and B is [n x p], out is [m x p]
+ int m = out.dimension(0);
+ int p = out.dimension(1);
+ int b_rows = b.dimension(0);
+ int b_cols = b.dimension(1);
+
+ // TODO(ebrevdo): Should this be alpha * nnz instead of
+ // out.size()? Perhaps p * nnz ?
+ CudaLaunchConfig config = GetCudaLaunchConfig(p * nnz, d);
+
+ SparseTensorDenseMatMulKernel<T, Tindices, ADJ_A, ADJ_B>
+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ nnz, m, b_rows, b_cols, p, a_indices.data(), a_values.data(),
+ b.data(), out.data());
return Status::OK();
}
diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
index 8099175186..a0bd178e24 100644
--- a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
@@ -29,6 +29,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
@@ -161,6 +162,46 @@ class SparseTensorDenseMatMulTest(test.TestCase):
sparse_ops.sparse_tensor_dense_matmul(
sparse_t, dense_t, adjoint_a=True).eval()
+ def testInvalidIndicesForSparseTensorDenseMatmulOnGPU(self):
+ # Note: use_gpu=False because nice errors are only returned from CPU kerne
+ if not test.is_gpu_available():
+ return
+ with self.test_session(use_gpu=True):
+ indices = np.array([[1, 10]]).astype(np.int64)
+ values = np.array([10]).astype(np.float32)
+ shape = [3, 2]
+ sparse_t = sparse_tensor.SparseTensor(indices, values, shape)
+
+ # Test multiplying by both a small and large dense matrix, to hit
+ # both cases in the kernel.
+ dense_t = np.matrix([[1] * 5, [2] * 5], dtype=np.float32)
+ expected_t = np.array([[0] * 5, [np.nan] * 5, [0] * 5], dtype=np.float32)
+ self.assertAllClose(expected_t,
+ sparse_ops.sparse_tensor_dense_matmul(
+ sparse_t, dense_t).eval())
+ dense_t = np.matrix([[1] * 500, [2] * 500], dtype=np.float32)
+ expected_t = np.array(
+ [[0] * 500, [np.nan] * 500, [0] * 500], dtype=np.float32)
+ self.assertAllClose(expected_t,
+ sparse_ops.sparse_tensor_dense_matmul(
+ sparse_t, dense_t).eval())
+
+ # Repeat with adjoint_a, now the error is that the sparse index
+ # is OOO w.r.t. the output. The GPU kernel can't do much here,
+ # so it just doesn't accumulate.
+
+ dense_t = np.matrix([[1] * 5, [2] * 5, [3] * 5], dtype=np.float32)
+ expected_t = np.array([[0] * 5, [0] * 5], dtype=np.float32)
+ self.assertAllClose(expected_t,
+ sparse_ops.sparse_tensor_dense_matmul(
+ sparse_t, dense_t, adjoint_a=True).eval())
+
+ dense_t = np.matrix([[1] * 500, [2] * 500, [3] * 500], dtype=np.float32)
+ expected_t = np.array([[0] * 500, [0] * 500], dtype=np.float32)
+ self.assertAllClose(expected_t,
+ sparse_ops.sparse_tensor_dense_matmul(
+ sparse_t, dense_t, adjoint_a=True).eval())
+
# Tests setting one dimension to be a high value.
def _testLarge(self, np_dtype):
r1 = np.random.randint(6000, 20000)
@@ -175,9 +216,12 @@ class SparseTensorDenseMatMulTest(test.TestCase):
y = _maybe_complex(np.random.randn(k, n).astype(np_dtype))
- self._testMatmul(x, y)
+ self._testMatmul(x, y, adjoint_a=False, adjoint_b=False)
+ self._testMatmul(x.transpose(), y, adjoint_a=True, adjoint_b=False)
+ self._testMatmul(x, y.transpose(), adjoint_a=False, adjoint_b=True)
+ self._testMatmul(
+ x.transpose(), y.transpose(), adjoint_a=True, adjoint_b=True)
- def testLarge(self):
np.random.seed(127) # Repeatable results
self._testLarge(np.float32)
self._testLarge(np.float64)
@@ -221,7 +265,9 @@ def _sparse_tensor_dense_vs_dense_matmul_benchmark_dense(x, y, adjoint_a,
lambda t, _: t < iterations,
body, (t0, v0),
parallel_iterations=1,
- back_prop=False)
+ back_prop=False,
+ shape_invariants=(tensor_shape.TensorShape(()),
+ tensor_shape.TensorShape(None)))
return [final]
return _timeit
@@ -246,7 +292,9 @@ def _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse(x_ind, x_val, x_shape,
lambda t, _: t < iterations,
body, (t0, v0),
parallel_iterations=1,
- back_prop=False)
+ back_prop=False,
+ shape_invariants=(tensor_shape.TensorShape(()),
+ tensor_shape.TensorShape(None)))
return [final]
return _timeit
@@ -291,7 +339,7 @@ def sparse_tensor_dense_vs_dense_matmul_benchmark(thresh,
if skip_dense:
delta_dense = float("nan")
else:
- with session.Session("", config=config, graph=ops.Graph()) as sess:
+ with session.Session(config=config, graph=ops.Graph()) as sess:
if not use_gpu:
with ops.device("/cpu:0"):
x_t = constant_op.constant(x)
@@ -299,12 +347,12 @@ def sparse_tensor_dense_vs_dense_matmul_benchmark(thresh,
ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_dense(
x_t, y_t, adjoint_a, adjoint_b)
else:
- x_t = constant_op.constant(x)
- y_t = constant_op.constant(y)
- ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_dense(x_t, y_t,
- adjoint_a,
- adjoint_b)
- delta_dense = _timer(sess, ops_fn, 1000)
+ with ops.device("/gpu:0"):
+ x_t = constant_op.constant(x)
+ y_t = constant_op.constant(y)
+ ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_dense(
+ x_t, y_t, adjoint_a, adjoint_b)
+ delta_dense = _timer(sess, ops_fn, 200)
# Using sparse_tensor_dense_matmul.
with session.Session("", config=config, graph=ops.Graph()) as sess:
@@ -317,13 +365,14 @@ def sparse_tensor_dense_vs_dense_matmul_benchmark(thresh,
ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse(
x_ind, x_val, x_shape, y_t, adjoint_a, adjoint_b)
else:
- x_ind = constant_op.constant(np.vstack(np.where(x)).astype(np.int64).T)
- x_val = constant_op.constant(x[np.where(x)])
- x_shape = constant_op.constant(np.array(x.shape).astype(np.int64))
- y_t = constant_op.constant(y)
- ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse(
- x_ind, x_val, x_shape, y_t, adjoint_a, adjoint_b)
- delta_sparse = _timer(sess, ops_fn, 1000)
+ with ops.device("/gpu:0"):
+ x_ind = constant_op.constant(np.vstack(np.where(x)).astype(np.int64).T)
+ x_val = constant_op.constant(x[np.where(x)])
+ x_shape = constant_op.constant(np.array(x.shape).astype(np.int64))
+ y_t = constant_op.constant(y)
+ ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse(
+ x_ind, x_val, x_shape, y_t, adjoint_a, adjoint_b)
+ delta_sparse = _timer(sess, ops_fn, 200)
print("%g \t %d \t %s \t %d \t %d \t %g \t %g \t %g" %
(1 - thresh, n, use_gpu, m, k, delta_dense, delta_sparse,
@@ -340,7 +389,7 @@ def main(_):
"\t dt(sparse)/dt(dense)")
for thresh in (0.99, 0.8, 0.5, 0.2):
- for n in (1, 10, 25):
+ for n in (50, 100):
for use_gpu in (True, False):
for m in (100, 1000):
for k in (100, 1000):