-rw-r--r--  tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc             |  65
-rw-r--r--  tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h              |   9
-rw-r--r--  tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc      | 146
-rw-r--r--  tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py |  87
4 files changed, 146 insertions, 161 deletions
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
index 30026f222a..30c57ef287 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
@@ -65,7 +65,8 @@ class SparseTensorDenseMatMulOp : public OpKernel {
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a_indices->shape()),
errors::InvalidArgument("Tensor 'a_indices' is not a matrix"));
- OP_REQUIRES(ctx, a_indices->shape().dim_size(0) == a_values->NumElements(),
+ const int64 nnz = a_indices->shape().dim_size(0);
+ OP_REQUIRES(ctx, nnz == a_values->NumElements(),
errors::InvalidArgument("Number of rows of a_indices does not "
"match number of entries in a_values"));
@@ -89,8 +90,28 @@ class SparseTensorDenseMatMulOp : public OpKernel {
inner_left, " vs. ", inner_right,
". Did you forget a transpose? "
"Dimensions of A: [",
- a_shape_t(0), ", ", a_shape_t(1), "). Dimensions of B: ",
- b->shape().DebugString()));
+ a_shape_t(0), ", ", a_shape_t(1),
+ "). Dimensions of B: ", b->shape().DebugString()));
+
+ if (std::is_same<Device, GPUDevice>::value) {
+ // The GPU implementation is optimized to use 32 bit indexing, so
+ // give a friendly error to the programmer early on if they
+ // exceed this limit.
+ const int int32max = std::numeric_limits<int>::max();
+ OP_REQUIRES(
+ ctx,
+ (FastBoundsCheck(inner_left, int32max) &&
+ FastBoundsCheck(inner_right, int32max) &&
+ FastBoundsCheck(outer_left, int32max) &&
+ FastBoundsCheck(outer_right, int32max) &&
+ FastBoundsCheck(b->NumElements(), int32max) &&
+ FastBoundsCheck(outer_left * outer_right, int32max) &&
+ FastBoundsCheck(a_values->NumElements(), int32max)),
+ errors::InvalidArgument("Cannot use GPU for > 2^31 entry inputs"));
+ OP_REQUIRES(ctx, FastBoundsCheck(nnz * outer_right, int32max),
+ errors::InvalidArgument(
+ "Cannot use GPU when output.shape[1] * nnz(a) > 2^31"));
+ }
TensorShape out_shape({outer_left, outer_right});
Tensor* out = nullptr;
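
The added checks above amount to the following test. A minimal Python sketch, assuming FastBoundsCheck(v, limit) returns true iff 0 <= v < limit (its contract in bounds_check.h); the function name below is illustrative, not TensorFlow API:

    import numpy as np

    INT32_MAX = np.iinfo(np.int32).max  # 2**31 - 1

    def fast_bounds_check(v, limit):
        # Assumed semantics of FastBoundsCheck: true iff 0 <= v < limit.
        return 0 <= v < limit

    def gpu_inputs_fit_int32(inner_left, inner_right, outer_left, outer_right,
                             b_size, nnz):
        # Every extent, the output element count, nnz, and nnz * output_cols
        # must stay below 2**31: the kernel does 32-bit index arithmetic.
        return (fast_bounds_check(inner_left, INT32_MAX) and
                fast_bounds_check(inner_right, INT32_MAX) and
                fast_bounds_check(outer_left, INT32_MAX) and
                fast_bounds_check(outer_right, INT32_MAX) and
                fast_bounds_check(b_size, INT32_MAX) and
                fast_bounds_check(outer_left * outer_right, INT32_MAX) and
                fast_bounds_check(nnz, INT32_MAX) and
                fast_bounds_check(nnz * outer_right, INT32_MAX))
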
@@ -111,41 +132,13 @@ class SparseTensorDenseMatMulOp : public OpKernel {
return;
}
- Tensor scratch;
-
- if (std::is_same<Device, GPUDevice>::value) {
- // The GPU implementation is optimized to use 32 bit indexing, so
- // give a friendly error to the programmer early on if they exceed.
- OP_REQUIRES(
- ctx,
- FastBoundsCheck(inner_left, std::numeric_limits<int>::max()) &&
- FastBoundsCheck(inner_right, std::numeric_limits<int>::max()) &&
- FastBoundsCheck(outer_left, std::numeric_limits<int>::max()) &&
- FastBoundsCheck(outer_right, std::numeric_limits<int>::max()) &&
- FastBoundsCheck(b->NumElements(),
- std::numeric_limits<int>::max()) &&
- FastBoundsCheck(out->NumElements(),
- std::numeric_limits<int>::max()) &&
- FastBoundsCheck(a_values->NumElements(),
- std::numeric_limits<int>::max()),
- errors::InvalidArgument("Cannot use GPU for > 2^31 entry inputs"));
- const int nnz = static_cast<const int>(a_values->NumElements());
- // Need nnz length vec scratch space on the GPU.
- OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({nnz}), &scratch));
- } else {
- // We don't need scratch space on the CPU.
- OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({0}), &scratch));
- }
-
#define MAYBE_ADJOINT(ADJ_A, ADJ_B) \
if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) { \
Status functor_status = functor::SparseTensorDenseMatMulFunctor< \
Device, T, Tindices, ADJ_A, \
ADJ_B>::Compute(ctx->eigen_device<Device>(), out->matrix<T>(), \
a_indices->matrix<Tindices>(), a_values->vec<T>(), \
- b->matrix<T>(), scratch.vec<T>()); \
+ b->matrix<T>()); \
OP_REQUIRES_OK(ctx, functor_status); \
}
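
The macro expands to one guarded call per (ADJ_A, ADJ_B) combination, selecting the matching compile-time specialization from the two boolean attrs. A Python sketch of that dispatch, with illustrative names:

    def maybe_adjoint_dispatch(adjoint_a, adjoint_b, functors):
        # functors maps each (ADJ_A, ADJ_B) pair to a compiled specialization;
        # the four MAYBE_ADJOINT expansions amount to this lookup.
        return functors[(adjoint_a, adjoint_b)]
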
@@ -189,10 +182,9 @@ namespace functor {
Status SparseTensorDenseMatMulFunctor< \
GPUDevice, T, Tindices, ADJ_A, \
ADJ_B>::Compute(const GPUDevice& d, typename TTypes<T>::Matrix out, \
- typename TTypes<Tindices>::ConstMatrix a_indices, \
+ TTypes<Tindices>::ConstMatrix a_indices, \
typename TTypes<T>::ConstVec a_values, \
- typename TTypes<T>::ConstMatrix b, \
- typename TTypes<T>::Vec scratch); \
+ typename TTypes<T>::ConstMatrix b); \
extern template struct SparseTensorDenseMatMulFunctor< \
GPUDevice, T, Tindices, ADJ_A, ADJ_B>;
@@ -255,8 +247,7 @@ struct SparseTensorDenseMatMulFunctor<CPUDevice, T, Tindices, ADJ_A, ADJ_B> {
static Status Compute(const CPUDevice& d, typename TTypes<T>::Matrix out,
typename TTypes<Tindices>::ConstMatrix a_indices,
typename TTypes<T>::ConstVec a_values,
- typename TTypes<T>::ConstMatrix b,
- typename TTypes<T>::Vec scratch) {
+ typename TTypes<T>::ConstMatrix b) {
const std::size_t nnz = a_values.size();
const std::size_t rhs_right = (ADJ_B ? b.dimension(0) : b.dimension(1));
const std::size_t lhs_right = (ADJ_B ? b.dimension(1) : b.dimension(0));
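
For orientation, the CPU and GPU functors compute the same result. A NumPy reference sketch of the op's semantics, with illustrative names and complex conjugation omitted for brevity:

    import numpy as np

    def sparse_dense_matmul_reference(a_indices, a_values, b, out_rows,
                                      adjoint_a=False, adjoint_b=False):
        # A is COO-encoded: a_indices is [nnz, 2] (row, col), a_values is [nnz].
        bt = b.T if adjoint_b else b                # op(B), shape [n, p]
        out = np.zeros((out_rows, bt.shape[1]), dtype=b.dtype)
        for (r, c), v in zip(a_indices, a_values):
            i, k = (c, r) if adjoint_a else (r, c)  # op(A) has a nonzero at (i, k)
            out[i, :] += v * bt[k, :]               # scatter into output row i
        return out

With valid indices this matches densifying A and calling np.matmul on op(A) and op(B).
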
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
index e707743f78..da13190494 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
@@ -28,11 +28,10 @@ namespace functor {
template <typename Device, typename T, typename Tindices, bool ADJ_A,
bool ADJ_B>
struct SparseTensorDenseMatMulFunctor {
- static EIGEN_ALWAYS_INLINE Status
- Compute(const Device& d, typename TTypes<T>::Matrix out,
- typename TTypes<Tindices>::ConstMatrix a_indices,
- typename TTypes<T>::ConstVec a_values,
- typename TTypes<T>::ConstMatrix b, typename TTypes<T>::Vec scratch);
+ static EIGEN_ALWAYS_INLINE Status Compute(
+ const Device& d, typename TTypes<T>::Matrix out,
+ typename TTypes<Tindices>::ConstMatrix a_indices,
+ typename TTypes<T>::ConstVec a_values, typename TTypes<T>::ConstMatrix b);
};
template <typename MATRIX, bool ADJ>
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
index 7266e0cf81..e261e42e0d 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
@@ -20,71 +20,45 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h"
#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
-namespace generator {
-
template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
-class SparseTensorDenseMatMulGPUGenerator {
- public:
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseTensorDenseMatMulGPUGenerator(
- typename TTypes<T, 2>::Tensor32Bit out,
- typename TTypes<const Tindices, 2>::Tensor32Bit a_indices,
- typename TTypes<const T, 1>::Tensor32Bit a_values,
- typename TTypes<const T, 2>::Tensor32Bit b)
- : out_(out),
- lhs_index_a_(ADJ_A ? 1 : 0),
- rhs_index_a_(ADJ_A ? 0 : 1),
- a_indices_(a_indices),
- a_values_(a_values),
- lhs_right_size(ADJ_B ? b.dimension(1) : b.dimension(0)),
- maybe_adjoint_b_(
- functor::MaybeAdjoint<typename TTypes<const T, 2>::Tensor32Bit,
- ADJ_B>(b)) {}
-
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
- operator()(const Eigen::array<int, 2>& j_and_ix) const {
-#ifdef __CUDA_ARCH__
- const int j = j_and_ix[0];
- const int ix = j_and_ix[1];
- int m = a_indices_(ix, lhs_index_a_);
- int k = a_indices_(ix, rhs_index_a_);
- assert(k < lhs_right_size);
- assert(m < out_.dimension(0));
- // If asserts are disabled, the caller is violating the sparse
- // tensor index contract, and so we return invalid results.
- // Force returning NaNs to try to signal that something is amiss.
- T b_value;
- if (k >= lhs_right_size || m >= out_.dimension(0)) {
- m = 0;
- k = 0;
- b_value = std::numeric_limits<T>::quiet_NaN();
- } else {
- b_value = maybe_adjoint_b_(k, j);
+__global__ void SparseTensorDenseMatMulKernel(int nnz, int m, int b_rows,
+ int b_cols, int p,
+ const Tindices* a_indices,
+ const T* a_values, const T* b,
+ T* out) {
+ // out_{ij} = sum_k {a_ik b_kj}
+ // out = A * B', out_{ij} = sum_k {a_ik (b')_kj}; b'_{kj} = b_{jk}
+ const int n = (ADJ_B) ? b_cols : b_rows;
+ CUDA_1D_KERNEL_LOOP(index, nnz * p) {
+ const int a_ix = index / p;
+ const int j = index % p;
+ const int i = ldg(a_indices + 2 * a_ix + ((ADJ_A) ? 1 : 0));
+ const int k = ldg(a_indices + 2 * a_ix + ((ADJ_A) ? 0 : 1));
+ if (!FastBoundsCheck(i, m)) {
+ continue; // Nowhere to signal an error :(
+ }
+ // out[i, j]
+ T* out_location = out + i * p + j;
+ if (!FastBoundsCheck(k, n)) {
+ CudaAtomicAdd(out_location, std::numeric_limits<T>::quiet_NaN());
+ continue;
}
- atomicAdd(&out_(m, j), a_values_(ix) * b_value);
-#else
- assert(false && "This should only be run on the device");
-#endif
- // Return something
- return T(0);
- }
- private:
- mutable typename TTypes<T, 2>::Tensor32Bit out_;
- const int lhs_index_a_;
- const int rhs_index_a_;
- typename TTypes<const Tindices, 2>::Tensor32Bit a_indices_;
- typename TTypes<const T, 1>::Tensor32Bit a_values_;
- const int lhs_right_size;
- functor::MaybeAdjoint<typename TTypes<const T, 2>::Tensor32Bit, ADJ_B>
- maybe_adjoint_b_;
-};
+ // a_value == (ADJ_A) ? a[k, i] : a[i, k]
+ const T a_value = ldg(a_values + a_ix);
-} // namespace generator
+ // b_value == (ADJ_B) ? b[j, k] : b[k, j]
+ const T b_value = ldg(b + ((ADJ_B) ? j * b_cols + k : k * b_cols + j));
+ CudaAtomicAdd(out_location, a_value * b_value);
+ }
+}
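
Each of the nnz * p work items handles one (nonzero, output column) pair, splitting the flat index as a_ix = index / p, j = index % p. The same logic in a Python sketch; kernel_semantics is an illustrative name, and on the GPU the final addition is atomic:

    def kernel_semantics(nnz, m, n, p, a_indices, a_values, b, out,
                         adj_a, adj_b):
        for index in range(nnz * p):       # range covered by CUDA_1D_KERNEL_LOOP
            a_ix, j = divmod(index, p)
            i = a_indices[a_ix][1 if adj_a else 0]
            k = a_indices[a_ix][0 if adj_a else 1]
            if not 0 <= i < m:
                continue                   # bad row: nowhere to signal an error
            if not 0 <= k < n:
                out[i][j] += float("nan")  # bad column: poison out[i, j]
                continue
            b_value = b[j][k] if adj_b else b[k][j]
            out[i][j] += a_values[a_ix] * b_value
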
namespace functor {
@@ -94,51 +68,23 @@ struct SparseTensorDenseMatMulFunctor<GPUDevice, T, Tindices, ADJ_A, ADJ_B> {
Compute(const GPUDevice& d, typename TTypes<T>::Matrix out,
typename TTypes<Tindices>::ConstMatrix a_indices,
typename TTypes<T>::ConstVec a_values,
- typename TTypes<T>::ConstMatrix b, typename TTypes<T>::Vec scratch) {
- generator::SparseTensorDenseMatMulGPUGenerator<T, Tindices, ADJ_A, ADJ_B>
- sparse_tensor_dense_matmul_generator(To32Bit(out), To32Bit(a_indices),
- To32Bit(a_values), To32Bit(b));
- To32Bit(out).device(d) = To32Bit(out).constant(T(0));
+ typename TTypes<T>::ConstMatrix b) {
+ out.device(d) = out.constant(T(0));
int nnz = a_values.size();
- int n = (ADJ_B) ? b.dimension(0) : b.dimension(1);
-
-#if !defined(EIGEN_HAS_INDEX_LIST)
- Eigen::Tensor<int, 2>::Dimensions matrix_1_by_nnz{{ 1, nnz }};
- Eigen::array<int, 2> n_by_1{{ n, 1 }};
- Eigen::array<int, 1> reduce_on_rows{{ 0 }};
-#else
- Eigen::IndexList<Eigen::type2index<1>, int> matrix_1_by_nnz;
- matrix_1_by_nnz.set(1, nnz);
- Eigen::IndexList<int, Eigen::type2index<1> > n_by_1;
- n_by_1.set(0, n);
- Eigen::IndexList<Eigen::type2index<0> > reduce_on_rows;
-#endif
-
- // How this works: the generator iterates over (j, ix) where j
- // iterates from 0 .. n - 1 and ix iterates from
- // 0 .. nnz - 1. A side effect of the generator is to accumulate
- // the products of values in A and B into the appropriate location
- // in the dense matrix out. In order to run the iteration,
- // we take a smaller variable and broadcast to a size (n, nnz).
- // This is the scratch variable. In order to enforce execution,
- // we have to perform assignment back into scratch (taking the sum).
- // We don't care what gets assigned to scratch - only the side effect
- // of the execution in the generator.
- //
- // Note it's not sufficient that scratch be a scalar, and to
- // broadcast it to a matrix. Eigen splits the computation not
- // based on the largest intermediate shape (the size of the
- // broadcast of scratch) but based on the output shape. So
- // scratch needs to be a vector at least.
- //
- // Note also that only float type is supported because the
- // atomicAdd operation is only supported for floats in hardware.
- To32Bit(scratch).device(d) =
- To32Bit(scratch)
- .reshape(matrix_1_by_nnz)
- .broadcast(n_by_1)
- .generate(sparse_tensor_dense_matmul_generator)
- .sum(reduce_on_rows);
+ // out = A * B, A is [m x n] and B is [n x p], out is [m x p]
+ int m = out.dimension(0);
+ int p = out.dimension(1);
+ int b_rows = b.dimension(0);
+ int b_cols = b.dimension(1);
+
+ // TODO(ebrevdo): Should this be alpha * nnz instead of
+ // out.size()? Perhaps p * nnz ?
+ CudaLaunchConfig config = GetCudaLaunchConfig(p * nnz, d);
+
+ SparseTensorDenseMatMulKernel<T, Tindices, ADJ_A, ADJ_B>
+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ nnz, m, b_rows, b_cols, p, a_indices.data(), a_values.data(),
+ b.data(), out.data());
return Status::OK();
}
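
CUDA_1D_KERNEL_LOOP is a grid-stride loop, so the launch stays correct even when p * nnz exceeds the number of launched threads: each thread walks its work items with a stride of the total thread count. The traversal per thread, sketched in Python:

    def grid_stride_indices(thread_id, total_threads, work_count):
        # Matches: for (i = blockIdx.x * blockDim.x + threadIdx.x;
        #               i < work_count; i += blockDim.x * gridDim.x)
        return range(thread_id, work_count, total_threads)
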
diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
index 8099175186..a0bd178e24 100644
--- a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
@@ -29,6 +29,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
@@ -161,6 +162,46 @@ class SparseTensorDenseMatMulTest(test.TestCase):
sparse_ops.sparse_tensor_dense_matmul(
sparse_t, dense_t, adjoint_a=True).eval()
+ def testInvalidIndicesForSparseTensorDenseMatmulOnGPU(self):
+ # Note: nice errors are only returned from the CPU kernel; the GPU
+ # kernel cannot raise errors, so it signals invalid indices by writing
+ # NaNs or by skipping the accumulation, which is what this test checks.
+ if not test.is_gpu_available():
+ return
+ with self.test_session(use_gpu=True):
+ indices = np.array([[1, 10]]).astype(np.int64)
+ values = np.array([10]).astype(np.float32)
+ shape = [3, 2]
+ sparse_t = sparse_tensor.SparseTensor(indices, values, shape)
+
+ # Test multiplying by both a small and large dense matrix, to hit
+ # both cases in the kernel.
+ dense_t = np.matrix([[1] * 5, [2] * 5], dtype=np.float32)
+ expected_t = np.array([[0] * 5, [np.nan] * 5, [0] * 5], dtype=np.float32)
+ self.assertAllClose(expected_t,
+ sparse_ops.sparse_tensor_dense_matmul(
+ sparse_t, dense_t).eval())
+ dense_t = np.matrix([[1] * 500, [2] * 500], dtype=np.float32)
+ expected_t = np.array(
+ [[0] * 500, [np.nan] * 500, [0] * 500], dtype=np.float32)
+ self.assertAllClose(expected_t,
+ sparse_ops.sparse_tensor_dense_matmul(
+ sparse_t, dense_t).eval())
+
+ # Repeat with adjoint_a; now the error is that the sparse index is
+ # out of bounds w.r.t. the output. The GPU kernel can't do much here,
+ # so it just doesn't accumulate.
+
+ dense_t = np.matrix([[1] * 5, [2] * 5, [3] * 5], dtype=np.float32)
+ expected_t = np.array([[0] * 5, [0] * 5], dtype=np.float32)
+ self.assertAllClose(expected_t,
+ sparse_ops.sparse_tensor_dense_matmul(
+ sparse_t, dense_t, adjoint_a=True).eval())
+
+ dense_t = np.matrix([[1] * 500, [2] * 500, [3] * 500], dtype=np.float32)
+ expected_t = np.array([[0] * 500, [0] * 500], dtype=np.float32)
+ self.assertAllClose(expected_t,
+ sparse_ops.sparse_tensor_dense_matmul(
+ sparse_t, dense_t, adjoint_a=True).eval())
+
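
Traced through the kernel sketch above, the bad entry [1, 10] in the [3, 2] sparse matrix behaves as the test expects: in the forward case B has only 2 rows, so k = 10 fails the column bounds check and NaN is accumulated into row 1 of the output, while with adjoint_a the output has only 2 rows, so i = 10 fails the row check and the entry is skipped, leaving the output all zeros.
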
# Tests setting one dimension to be a high value.
def _testLarge(self, np_dtype):
r1 = np.random.randint(6000, 20000)
@@ -175,9 +216,12 @@ class SparseTensorDenseMatMulTest(test.TestCase):
y = _maybe_complex(np.random.randn(k, n).astype(np_dtype))
- self._testMatmul(x, y)
+ self._testMatmul(x, y, adjoint_a=False, adjoint_b=False)
+ self._testMatmul(x.transpose(), y, adjoint_a=True, adjoint_b=False)
+ self._testMatmul(x, y.transpose(), adjoint_a=False, adjoint_b=True)
+ self._testMatmul(
+ x.transpose(), y.transpose(), adjoint_a=True, adjoint_b=True)
def testLarge(self):
np.random.seed(127) # Repeatable results
self._testLarge(np.float32)
self._testLarge(np.float64)
@@ -221,7 +265,9 @@ def _sparse_tensor_dense_vs_dense_matmul_benchmark_dense(x, y, adjoint_a,
lambda t, _: t < iterations,
body, (t0, v0),
parallel_iterations=1,
- back_prop=False)
+ back_prop=False,
+ shape_invariants=(tensor_shape.TensorShape(()),
+ tensor_shape.TensorShape(None)))
return [final]
return _timeit
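
The added shape_invariants tuple relaxes while_loop's static shape checking: the counter keeps a fixed scalar shape while the accumulated value may change shape across iterations, hence TensorShape(None). A minimal standalone sketch of the pattern, using the public tf namespace for brevity:

    import tensorflow as tf

    t0 = tf.constant(0)
    v0 = tf.zeros([2, 2])
    # Counter shape is pinned to a scalar; the value's shape is unconstrained.
    t_final, v_final = tf.while_loop(
        lambda t, _: t < 10,
        lambda t, v: (t + 1, v + 1.0),
        (t0, v0),
        parallel_iterations=1,
        back_prop=False,
        shape_invariants=(tf.TensorShape(()), tf.TensorShape(None)))
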
@@ -246,7 +292,9 @@ def _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse(x_ind, x_val, x_shape,
lambda t, _: t < iterations,
body, (t0, v0),
parallel_iterations=1,
- back_prop=False)
+ back_prop=False,
+ shape_invariants=(tensor_shape.TensorShape(()),
+ tensor_shape.TensorShape(None)))
return [final]
return _timeit
@@ -291,7 +339,7 @@ def sparse_tensor_dense_vs_dense_matmul_benchmark(thresh,
if skip_dense:
delta_dense = float("nan")
else:
- with session.Session("", config=config, graph=ops.Graph()) as sess:
+ with session.Session(config=config, graph=ops.Graph()) as sess:
if not use_gpu:
with ops.device("/cpu:0"):
x_t = constant_op.constant(x)
@@ -299,12 +347,12 @@ def sparse_tensor_dense_vs_dense_matmul_benchmark(thresh,
ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_dense(
x_t, y_t, adjoint_a, adjoint_b)
else:
- x_t = constant_op.constant(x)
- y_t = constant_op.constant(y)
- ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_dense(x_t, y_t,
- adjoint_a,
- adjoint_b)
- delta_dense = _timer(sess, ops_fn, 1000)
+ with ops.device("/gpu:0"):
+ x_t = constant_op.constant(x)
+ y_t = constant_op.constant(y)
+ ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_dense(
+ x_t, y_t, adjoint_a, adjoint_b)
+ delta_dense = _timer(sess, ops_fn, 200)
# Using sparse_tensor_dense_matmul.
with session.Session("", config=config, graph=ops.Graph()) as sess:
@@ -317,13 +365,14 @@ def sparse_tensor_dense_vs_dense_matmul_benchmark(thresh,
ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse(
x_ind, x_val, x_shape, y_t, adjoint_a, adjoint_b)
else:
- x_ind = constant_op.constant(np.vstack(np.where(x)).astype(np.int64).T)
- x_val = constant_op.constant(x[np.where(x)])
- x_shape = constant_op.constant(np.array(x.shape).astype(np.int64))
- y_t = constant_op.constant(y)
- ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse(
- x_ind, x_val, x_shape, y_t, adjoint_a, adjoint_b)
- delta_sparse = _timer(sess, ops_fn, 1000)
+ with ops.device("/gpu:0"):
+ x_ind = constant_op.constant(np.vstack(np.where(x)).astype(np.int64).T)
+ x_val = constant_op.constant(x[np.where(x)])
+ x_shape = constant_op.constant(np.array(x.shape).astype(np.int64))
+ y_t = constant_op.constant(y)
+ ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse(
+ x_ind, x_val, x_shape, y_t, adjoint_a, adjoint_b)
+ delta_sparse = _timer(sess, ops_fn, 200)
print("%g \t %d \t %s \t %d \t %d \t %g \t %g \t %g" %
(1 - thresh, n, use_gpu, m, k, delta_dense, delta_sparse,
@@ -340,7 +389,7 @@ def main(_):
"\t dt(sparse)/dt(dense)")
for thresh in (0.99, 0.8, 0.5, 0.2):
- for n in (1, 10, 25):
+ for n in (50, 100):
for use_gpu in (True, False):
for m in (100, 1000):
for k in (100, 1000):