Diffstat:
-rw-r--r--  tensorflow/contrib/cmake/CMakeLists.txt              |   3
-rw-r--r--  tensorflow/contrib/cmake/external/cub.cmake          |  29
-rw-r--r--  tensorflow/contrib/cmake/patches/cub/CMakeLists.txt  |   3
-rw-r--r--  tensorflow/core/framework/register_types.h           |   5
-rw-r--r--  tensorflow/core/kernels/BUILD                        |   5
-rw-r--r--  tensorflow/core/kernels/dense_update_ops.cc          |   2
-rw-r--r--  tensorflow/core/kernels/dense_update_ops_gpu.cu.cc   |   8
-rw-r--r--  tensorflow/core/kernels/resource_variable_ops.cc     |   4
-rw-r--r--  tensorflow/core/kernels/where_op.cc                  | 259
-rw-r--r--  tensorflow/core/kernels/where_op.h                   |  61
-rw-r--r--  tensorflow/core/kernels/where_op_gpu.cu.cc           | 253
-rw-r--r--  tensorflow/python/kernel_tests/BUILD                 |   2
-rw-r--r--  tensorflow/python/kernel_tests/where_op_test.py      |  59
-rw-r--r--  tensorflow/tools/lib_package/BUILD                   |   2
-rw-r--r--  tensorflow/tools/pip_package/BUILD                   |   1
-rw-r--r--  tensorflow/workspace.bzl                             |  16
-rw-r--r--  third_party/cub.BUILD                                |  26

17 files changed, 665 insertions(+), 73 deletions(-)
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt
index fe66f8c49a..655e0b29a0 100644
--- a/tensorflow/contrib/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/CMakeLists.txt
@@ -122,6 +122,7 @@ include(fft2d)
include(highwayhash)
include(protobuf)
include(re2)
+include(cub)
if (tensorflow_BUILD_CC_TESTS)
include(googletest)
endif()
@@ -151,6 +152,7 @@ set(tensorflow_EXTERNAL_DEPENDENCIES
protobuf
eigen
gemmlowp
+ cub
fft2d
re2
)
@@ -170,6 +172,7 @@ include_directories(
${jsoncpp_INCLUDE_DIR}
${farmhash_INCLUDE_DIR}
${highwayhash_INCLUDE_DIR}
+ ${cub_INCLUDE_DIR}
${PROTOBUF_INCLUDE_DIRS}
${re2_INCLUDE_DIR}
)
diff --git a/tensorflow/contrib/cmake/external/cub.cmake b/tensorflow/contrib/cmake/external/cub.cmake
new file mode 100644
index 0000000000..54329dbd70
--- /dev/null
+++ b/tensorflow/contrib/cmake/external/cub.cmake
@@ -0,0 +1,29 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+include (ExternalProject)
+
+set(cub_URL https://github.com/NVlabs/cub/archive/1.6.4.zip)
+set(cub_HASH SHA256=966d0c4f41e2bdc81aebf9ccfbf0baffaac5a74f00b826b06f4dee79b2bb8cee)
+set(cub_BUILD ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub)
+set(cub_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub)
+
+ExternalProject_Add(cub
+ PREFIX cub
+ URL ${cub_URL}
+ URL_HASH ${cub_HASH}
+ DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
+ BUILD_IN_SOURCE 1
+ PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/cub/CMakeLists.txt ${cub_BUILD}
+ INSTALL_COMMAND "")
diff --git a/tensorflow/contrib/cmake/patches/cub/CMakeLists.txt b/tensorflow/contrib/cmake/patches/cub/CMakeLists.txt
new file mode 100644
index 0000000000..36890f0ce6
--- /dev/null
+++ b/tensorflow/contrib/cmake/patches/cub/CMakeLists.txt
@@ -0,0 +1,3 @@
+cmake_minimum_required(VERSION 2.8.3)
+
+project(cub)
diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h
index 2f7b140295..b62fe647e2 100644
--- a/tensorflow/core/framework/register_types.h
+++ b/tensorflow/core/framework/register_types.h
@@ -167,10 +167,13 @@ limitations under the License.
// Call "m" on POD and string types.
#define TF_CALL_POD_STRING_TYPES(m) TF_CALL_POD_TYPES(m) TF_CALL_string(m)
-// Call "m" on all types supported on GPU.
+// Call "m" on all number types supported on GPU.
#define TF_CALL_GPU_NUMBER_TYPES(m) \
TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m)
+// Call "m" on all types supported on GPU.
+#define TF_CALL_GPU_ALL_TYPES(m) TF_CALL_GPU_NUMBER_TYPES(m) TF_CALL_bool(m)
+
#define TF_CALL_GPU_NUMBER_TYPES_NO_HALF(m) TF_CALL_float(m) TF_CALL_double(m)
// Call "m" on all quantized types.
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index f0dc3312c2..568e9b717d 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -763,7 +763,10 @@ tf_kernel_library(
tf_kernel_library(
name = "where_op",
prefix = "where_op",
- deps = ARRAY_DEPS,
+ deps = if_cuda([
+ ":cuda_solvers",
+ "@cub_archive//:cub",
+ ]) + ARRAY_DEPS,
)
tf_cc_test(
diff --git a/tensorflow/core/kernels/dense_update_ops.cc b/tensorflow/core/kernels/dense_update_ops.cc
index ef34946d96..af16c55a5d 100644
--- a/tensorflow/core/kernels/dense_update_ops.cc
+++ b/tensorflow/core/kernels/dense_update_ops.cc
@@ -154,7 +154,7 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
Name("Assign").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
AssignOpT<GPUDevice, type>);
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
+TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/dense_update_ops_gpu.cu.cc b/tensorflow/core/kernels/dense_update_ops_gpu.cu.cc
index 0f61506afe..cf2d493061 100644
--- a/tensorflow/core/kernels/dense_update_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/dense_update_ops_gpu.cu.cc
@@ -27,11 +27,15 @@ typedef Eigen::GpuDevice GPUDevice;
#define DEFINE_GPU_KERNELS(T) \
template struct functor::DenseUpdate<GPUDevice, T, ADD>; \
- template struct functor::DenseUpdate<GPUDevice, T, SUB>; \
- template struct functor::DenseUpdate<GPUDevice, T, ASSIGN>;
+ template struct functor::DenseUpdate<GPUDevice, T, SUB>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
#undef DEFINE_GPU_KERNELS
+#define DEFINE_GPU_KERNELS(T) \
+ template struct functor::DenseUpdate<GPUDevice, T, ASSIGN>;
+TF_CALL_GPU_ALL_TYPES(DEFINE_GPU_KERNELS);
+#undef DEFINE_GPU_KERNELS
+
} // end namespace tensorflow
#endif // GOOGLE_CUDA
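The DEFINE_GPU_KERNELS split above exists because bool only gains the ASSIGN
specialization (needed for bool variables on GPU), while the arithmetic ADD and
SUB updates remain number-types-only. A simplified, assumed sketch of the same
explicit-instantiation pattern (not the real DenseUpdate functor):

    #include <iostream>

    enum UpdateOp { ADD, SUB, ASSIGN };

    // Simplified stand-in for functor::DenseUpdate<Device, T, Op>.
    template <typename T, UpdateOp Op>
    struct DenseUpdateSketch {
      void operator()(T* params, const T* update, int n) const {
        for (int i = 0; i < n; ++i) {
          if (Op == ASSIGN)
            params[i] = update[i];
          else if (Op == ADD)
            params[i] = params[i] + update[i];
          else
            params[i] = params[i] - update[i];
        }
      }
    };

    // Arithmetic updates are stamped out for numeric types only...
    template struct DenseUpdateSketch<float, ADD>;
    template struct DenseUpdateSketch<float, SUB>;
    template struct DenseUpdateSketch<float, ASSIGN>;
    // ...while bool additionally gets ASSIGN alone, matching the
    // TF_CALL_GPU_ALL_TYPES list.
    template struct DenseUpdateSketch<bool, ASSIGN>;

    int main() {
      float p[2] = {1.f, 2.f}, u[2] = {10.f, 20.f};
      DenseUpdateSketch<float, ADD>()(p, u, 2);
      std::cout << p[0] << " " << p[1] << "\n";  // prints: 11 22
    }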
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index db6d8c20ec..e108a59275 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -94,7 +94,7 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
.HostMemory("resource"), \
ReadVariableOp<GPUDevice, type>);
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
+TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA
@@ -230,7 +230,7 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
.HostMemory("resource"), \
AssignVariableOp<GPUDevice, type>);
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
+TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/where_op.cc b/tensorflow/core/kernels/where_op.cc
index e56a498845..6fdcb331cc 100644
--- a/tensorflow/core/kernels/where_op.cc
+++ b/tensorflow/core/kernels/where_op.cc
@@ -17,9 +17,14 @@ limitations under the License.
#define EIGEN_USE_THREADS
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif // GOOGLE_CUDA
+
#include "tensorflow/core/kernels/where_op.h"
#include <memory>
+#include <numeric>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -27,43 +32,113 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
+#if GOOGLE_CUDA
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
+#include "tensorflow/core/kernels/cuda_solvers.h"
+#endif // GOOGLE_CUDA
+
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-template <typename Device>
-class WhereOp : public OpKernel {
+namespace functor {
+
+template <>
+struct NumTrue<CPUDevice, int64> {
+ static Status Compute(OpKernelContext* ctx, const CPUDevice& d,
+ TTypes<bool>::ConstFlat input,
+ TTypes<int64>::Scalar num_true) {
+ *num_true.data() =
+ std::accumulate(input.data(), input.data() + input.size(), int64{0});
+ return Status::OK();
+ }
+};
+
+template <int DIMS, typename TIndex>
+struct Where<CPUDevice, DIMS, TIndex> {
+ EIGEN_ALWAYS_INLINE static void WriteIndexRowMajor(
+ typename TTypes<int64>::Matrix output,
+ const typename Eigen::DSizes<TIndex, DIMS>& strides, TIndex true_n,
+ TIndex index) {
+ for (int i = 0; i < DIMS; ++i) {
+ output(true_n, i) = index / strides[i];
+ index -= output(true_n, i) * strides[i];
+ }
+ }
+
+ EIGEN_ALWAYS_INLINE static Status Compute(
+ OpKernelContext* ctx, const CPUDevice& d,
+ typename TTypes<bool, DIMS>::ConstTensor input,
+ typename TTypes<int64>::Matrix output, TIndex* found_true) {
+ Eigen::DSizes<Eigen::DenseIndex, DIMS> dims = input.dimensions();
+ Eigen::DSizes<TIndex, DIMS> strides;
+
+ EIGEN_STATIC_ASSERT((static_cast<int>(decltype(input)::Layout) ==
+ static_cast<int>(Eigen::RowMajor)),
+ INTERNAL_ERROR_INPUT_SHOULD_BE_ROWMAJOR);
+
+ strides[DIMS - 1] = 1;
+ for (int i = DIMS - 2; i >= 0; --i) {
+ strides[i] = strides[i + 1] * dims[i + 1];
+ }
+
+ Eigen::DenseIndex output_size = output.dimension(0);
+ for (Eigen::DenseIndex n = 0; n < input.size(); ++n) {
+ if (input.data()[n]) {
+ if (FastBoundsCheck(*found_true, output_size)) {
+ WriteIndexRowMajor(output, strides, *found_true, n);
+ }
+ ++*found_true;
+ }
+ }
+ return Status::OK();
+ }
+};
+
+} // namespace functor
+
+class WhereCPUOp : public OpKernel {
public:
- explicit WhereOp(OpKernelConstruction* context) : OpKernel(context) {}
+ explicit WhereCPUOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
const int input_dims = input.dims();
+
Tensor num_true;
OP_REQUIRES_OK(
context, context->allocate_temp(DT_INT64, TensorShape({}), &num_true));
auto num_true_t = num_true.scalar<int64>();
- functor::NumTrue<Device>::Compute(context->eigen_device<Device>(),
- input.flat<bool>(), num_true_t);
+ Status s = functor::NumTrue<CPUDevice, int64>::Compute(
+ context, context->eigen_device<CPUDevice>(), input.flat<bool>(),
+ num_true_t);
+ OP_REQUIRES_OK(context, s);
TensorShape output_shape({num_true_t(), input_dims});
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
-#define HANDLE_DIM(NDIM) \
- case NDIM: \
- found_true = functor::Where<Device, NDIM>::Compute( \
- context->eigen_device<Device>(), input.tensor<bool, NDIM>(), \
- output->matrix<int64>()); \
- break;
-
+ // TODO(ebrevdo): Replace single-threaded copy with a
+ // multithreaded block copy by getting block counts above instead
+ // of a global NumTrue, then filling in each block in separate
+ // threads below.
int64 found_true = 0;
+
+#define HANDLE_DIM(NDIM) \
+ case NDIM: { \
+ Status s = functor::Where<CPUDevice, NDIM, int64>::Compute( \
+ context, context->eigen_device<CPUDevice>(), \
+ input.tensor<bool, NDIM>(), output->matrix<int64>(), &found_true); \
+ OP_REQUIRES_OK(context, s); \
+ } break;
+
switch (input_dims) {
HANDLE_DIM(1);
HANDLE_DIM(2);
@@ -79,7 +154,7 @@ class WhereOp : public OpKernel {
#undef HANDLE_DIM
OP_REQUIRES(
- context, num_true_t() == found_true,
+ context, found_true == num_true_t(),
errors::InvalidArgument(
"WhereOp: Race condition between counting the number of true "
"elements and writing them. When counting, saw ",
@@ -88,12 +163,162 @@ class WhereOp : public OpKernel {
}
private:
- TF_DISALLOW_COPY_AND_ASSIGN(WhereOp);
+ TF_DISALLOW_COPY_AND_ASSIGN(WhereCPUOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("Where").Device(DEVICE_CPU), WhereCPUOp);
+
+#if GOOGLE_CUDA
+
+namespace functor {
+
+#define DECLARE_GPU_NUMTRUE(Tindex) \
+ template <> \
+ Status NumTrue<GPUDevice, Tindex>::Compute( \
+ OpKernelContext* ctx, const GPUDevice& d, TTypes<bool>::ConstFlat input, \
+ TTypes<Tindex>::Scalar num_true); \
+ extern template struct NumTrue<GPUDevice, Tindex>
+
+DECLARE_GPU_NUMTRUE(int32);
+DECLARE_GPU_NUMTRUE(int64);
+#undef DECLARE_GPU_NUMTRUE
+
+#define DECLARE_GPU_WHERE_INDEX(Dims, Tindex) \
+ template <> \
+ Status Where<GPUDevice, Dims, Tindex>::Compute( \
+ OpKernelContext* ctx, const GPUDevice& d, \
+ typename TTypes<bool, Dims>::ConstTensor input, \
+ typename TTypes<int64>::Matrix output, Tindex* found_true); \
+ extern template struct Where<GPUDevice, Dims, Tindex>;
+#define DECLARE_GPU_WHERE(Dims) \
+ DECLARE_GPU_WHERE_INDEX(Dims, int32); \
+ DECLARE_GPU_WHERE_INDEX(Dims, int64);
+
+DECLARE_GPU_WHERE(1);
+DECLARE_GPU_WHERE(2);
+DECLARE_GPU_WHERE(3);
+DECLARE_GPU_WHERE(4);
+DECLARE_GPU_WHERE(5);
+#undef DECLARE_GPU_WHERE
+#undef DECLARE_GPU_WHERE_INDEX
+
+} // namespace functor
+
+class WhereGPUOp : public AsyncOpKernel {
+ public:
+ explicit WhereGPUOp(OpKernelConstruction* context) : AsyncOpKernel(context) {}
+
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
+ const Tensor& input = context->input(0);
+ const int input_dims = input.dims();
+
+ if (input.NumElements() < std::numeric_limits<int32>::max()) {
+ ComputeAsyncType<int32>(input, input_dims, context, done);
+ } else {
+ ComputeAsyncType<int64>(input, input_dims, context, done);
+ }
+ }
+
+ template <typename Tindex>
+ void ComputeAsyncType(const Tensor& input, const int input_dims,
+ OpKernelContext* context, DoneCallback done) {
+ // Step 0: alloc nnz
+ // Step 1: call nnz kernel
+ // Step 2: copy nnz to host
+ // Step 3: call create_output
+ // Step 4: call where kernel
+ Tensor num_true;
+ OP_REQUIRES_OK_ASYNC(context,
+ context->allocate_temp(DataTypeToEnum<Tindex>::v(),
+ TensorShape({}), &num_true),
+ done);
+
+ auto num_true_t = num_true.scalar<Tindex>();
+
+ perftools::gputools::DeviceMemoryBase num_true_ptr(
+ static_cast<void*>(num_true_t.data()));
+ // Push kernel to stream to get number of true elements.
+ const GPUDevice& d = context->eigen_device<GPUDevice>();
+ Status s = functor::NumTrue<GPUDevice, Tindex>::Compute(
+ context, d, input.flat<bool>(), num_true_t);
+ OP_REQUIRES_OK_ASYNC(context, s, done);
+
+ // Copy num_true to host.
+ ScratchSpace<Tindex> num_true_host(context, 1, /* on_host */ true);
+
+ auto stream = context->op_device_context()->stream();
+ OP_REQUIRES_ASYNC(
+ context,
+ stream
+ ->ThenMemcpy(num_true_host.mutable_data(), num_true_ptr,
+ sizeof(Tindex))
+ .ok(),
+ errors::Internal("WhereOp: failed to copy num_true from device"), done);
+
+ auto create_and_check_output = [context, &d, &input, input_dims,
+ num_true_host, done]() {
+
+ Tindex num_true = *num_true_host.data();
+
+ // TODO(ebrevdo): Properly copy back found_true value to CPU for
+ // validation checking. Currently Where<GPUDevice>::Compute()
+ // does not perform this copy back to CPU.
+ Tindex found_true = -1;
+
+ // Step 1: Allocate the output and perform the selection/copy.
+ Tensor* output;
+ OP_REQUIRES_OK_ASYNC(context,
+ context->allocate_output(
+ 0, TensorShape({num_true, input_dims}), &output),
+ done);
+
+#define HANDLE_DIM(NDIM) \
+ case NDIM: { \
+ Status s = functor::Where<GPUDevice, NDIM, Tindex>::Compute( \
+ context, d, input.tensor<bool, NDIM>(), output->matrix<int64>(), \
+ &found_true); \
+ OP_REQUIRES_OK_ASYNC(context, s, done); \
+ } break;
+
+ switch (input_dims) {
+ HANDLE_DIM(1);
+ HANDLE_DIM(2);
+ HANDLE_DIM(3);
+ HANDLE_DIM(4);
+ HANDLE_DIM(5);
+
+ default:
+ OP_REQUIRES_ASYNC(
+ context, false,
+ errors::InvalidArgument("WhereOp: Unhandled input dimensions: ",
+ input_dims),
+ done);
+ }
+#undef HANDLE_DIM
+
+ // TODO(ebrevdo): Fix the copy back to host.
+
+ // OP_REQUIRES_ASYNC(
+ // context, found_true == num_true,
+ // errors::InvalidArgument(
+ // "WhereOp: Race condition between counting the number of true "
+ // "elements and writing them. When counting, saw ",
+ // num_true, " elements; but when writing their indices, saw ",
+ // found_true, " elements."),
+ // done);
+
+ done();
+ };
+ context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
+ stream, create_and_check_output);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(WhereGPUOp);
};
-#define REGISTER_WHERE() \
- REGISTER_KERNEL_BUILDER(Name("Where").Device(DEVICE_CPU), WhereOp<CPUDevice>);
+REGISTER_KERNEL_BUILDER(Name("Where").Device(DEVICE_GPU), WhereGPUOp);
-REGISTER_WHERE();
+#endif // GOOGLE_CUDA
} // namespace tensorflow
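The stride arithmetic in WriteIndexRowMajor is the standard row-major decode:
strides[i] is the number of elements spanned by one step along dimension i, so
repeated division peels off one coordinate at a time. A self-contained sketch
using an assumed 2x3x4 shape:

    #include <cstdint>
    #include <iostream>

    int main() {
      const int64_t dims[3] = {2, 3, 4};  // hypothetical bool-tensor shape
      int64_t strides[3];
      strides[2] = 1;
      for (int i = 1; i >= 0; --i) strides[i] = strides[i + 1] * dims[i + 1];
      // strides == {12, 4, 1}

      int64_t index = 17;  // flat row-major position of a true element
      int64_t coords[3];
      for (int i = 0; i < 3; ++i) {
        coords[i] = index / strides[i];
        index -= coords[i] * strides[i];
      }
      // 17 == 1*12 + 1*4 + 1*1, so this prints: 1 1 1
      std::cout << coords[0] << " " << coords[1] << " " << coords[2] << "\n";
    }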
diff --git a/tensorflow/core/kernels/where_op.h b/tensorflow/core/kernels/where_op.h
index aa27123714..e040325e3d 100644
--- a/tensorflow/core/kernels/where_op.h
+++ b/tensorflow/core/kernels/where_op.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_KERNELS_WHERE_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -25,55 +26,25 @@ namespace tensorflow {
namespace functor {
-template <typename Device>
+template <typename Device, typename TIndex>
struct NumTrue {
- EIGEN_ALWAYS_INLINE static void Compute(
- const Device& d, typename TTypes<bool>::ConstFlat input,
- TTypes<int64>::Scalar num_true) {
- num_true.device(d) = input.template cast<int64>().sum();
- }
+ EIGEN_ALWAYS_INLINE static Status Compute(
+ OpKernelContext* ctx, const Device& d, TTypes<bool>::ConstFlat input,
+ typename TTypes<TIndex>::Scalar num_true);
};
-template <typename Device, int NDIM>
+template <typename Device, int NDIM, typename TIndex>
struct Where {
- EIGEN_ALWAYS_INLINE static int64 Compute(
- const Device& d, typename TTypes<bool, NDIM>::ConstTensor input,
- typename TTypes<int64>::Matrix output) {
- Eigen::DenseIndex true_n = 0;
- Eigen::DSizes<Eigen::DenseIndex, NDIM> dims = input.dimensions();
- Eigen::DSizes<Eigen::DenseIndex, NDIM> strides;
-
- // Calculate strides for RowMajor order.
- EIGEN_STATIC_ASSERT((static_cast<int>(decltype(input)::Layout) ==
- static_cast<int>(Eigen::RowMajor)),
- INTERNAL_ERROR_INPUT_SHOULD_BE_ROWMAJOR);
-
- strides[NDIM - 1] = 1;
- for (int i = NDIM - 2; i >= 0; --i) {
- strides[i] = strides[i + 1] * dims[i + 1];
- }
-
- Eigen::DenseIndex output_size = output.dimension(0);
- for (Eigen::DenseIndex n = 0; n < input.size(); ++n) {
- if (input.data()[n]) {
- if (TF_PREDICT_TRUE(true_n < output_size)) {
- WriteIndexRowMajor(output, strides, true_n, n);
- }
- ++true_n;
- }
- }
- return true_n;
- }
-
- EIGEN_ALWAYS_INLINE static void WriteIndexRowMajor(
- typename TTypes<int64>::Matrix output,
- const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides,
- Eigen::DenseIndex true_n, Eigen::DenseIndex index) {
- for (int i = 0; i < NDIM; ++i) {
- output(true_n, i) = index / strides[i];
- index %= strides[i];
- }
- }
+ // Copies indices of true values in input into output. The pointer
+ // found_true should sit on the host. Compute should copy the
+ // number of true elements found into it. At the end, if
+ // *found_true != output.dimension(0),
+ // then the input may have changed between the initial counting of
+ // the true values and the call to Where.
+ EIGEN_ALWAYS_INLINE static Status Compute(
+ OpKernelContext* ctx, const Device& d,
+ typename TTypes<bool, NDIM>::ConstTensor input,
+ typename TTypes<int64>::Matrix output, TIndex* found_true);
};
} // namespace functor
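To make the found_true contract concrete: the op counts true values, sizes the
output from that count, then re-counts while writing; a mismatch means the
input mutated in between. A minimal host-only sketch of that protocol
(hypothetical code, not the TensorFlow functors):

    #include <cstdint>
    #include <iostream>
    #include <stdexcept>
    #include <vector>

    int64_t CountTrue(const std::vector<bool>& v) {  // the "NumTrue" pass
      int64_t n = 0;
      for (bool b : v) n += b;
      return n;
    }

    int main() {
      std::vector<bool> input = {true, false, true};
      const int64_t num_true = CountTrue(input);
      std::vector<int64_t> output;  // rows are budgeted from the first count

      int64_t found_true = 0;  // the "Where" pass re-counts as it writes
      for (int64_t i = 0; i < static_cast<int64_t>(input.size()); ++i) {
        if (input[i]) {
          if (found_true < num_true) output.push_back(i);
          ++found_true;
        }
      }
      if (found_true != num_true)
        throw std::runtime_error("input changed between counting and writing");
      std::cout << "wrote " << output.size() << " indices\n";  // wrote 2 indices
    }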
diff --git a/tensorflow/core/kernels/where_op_gpu.cu.cc b/tensorflow/core/kernels/where_op_gpu.cu.cc
new file mode 100644
index 0000000000..09e8be58fa
--- /dev/null
+++ b/tensorflow/core/kernels/where_op_gpu.cu.cc
@@ -0,0 +1,253 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "external/cub_archive/cub/device/device_reduce.cuh"
+#include "external/cub_archive/cub/device/device_select.cuh"
+#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/where_op.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+template <int NDIM, typename TIndex>
+__global__ void PropagateWhereIndicesKernel(
+ const TIndex output_rows, const typename Eigen::array<TIndex, NDIM> strides,
+ int64* output) {
+ // TODO(ebrevdo): Use a multi-dimensional loop, increasing the
+ // dimensions of individual indices manually, instead of relying on
+ // a scalar loop variable and using integer division.
+ CUDA_1D_KERNEL_LOOP(i, output_rows) {
+ TIndex index_value = ldg(output + NDIM * i);
+#pragma unroll
+ for (int c = 0; c < NDIM; ++c) {
+ *(output + NDIM * i + c) = index_value / strides[c];
+ index_value %= strides[c];
+ }
+ }
+}
+
+template <typename TIndex>
+struct NumTrue<GPUDevice, TIndex> {
+ EIGEN_ALWAYS_INLINE static Status Compute(
+ OpKernelContext* ctx, const GPUDevice& d, TTypes<bool>::ConstFlat input,
+ typename TTypes<TIndex>::Scalar num_true) {
+ std::size_t temp_storage_bytes = 0;
+
+ const bool* input_data = input.data();
+ TIndex* num_true_data = num_true.data();
+
+ auto first_success =
+ cub::DeviceReduce::Sum(/*temp_storage*/ nullptr, temp_storage_bytes,
+ /*d_in*/ input_data,
+ /*d_out*/ num_true_data,
+ /*num_items*/ input.size(),
+ /*stream*/ d.stream());
+
+ if (first_success != cudaSuccess) {
+ return errors::Internal(
+ "WhereOp: Could not launch cub::DeviceReduce::Sum to calculate "
+ "temp_storage_bytes, status: ",
+ cudaGetErrorString(first_success));
+ }
+
+ Tensor temp_storage;
+ TF_RETURN_IF_ERROR(ctx->allocate_temp(
+ DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
+ &temp_storage));
+
+ auto second_success = cub::DeviceReduce::Sum(
+ /*temp_storage*/ temp_storage.flat<int8>().data(), temp_storage_bytes,
+ /*d_in*/ input_data,
+ /*d_out*/ num_true_data,
+ /*num_items*/ input.size(),
+ /*stream*/ d.stream());
+
+ if (second_success != cudaSuccess) {
+ return errors::Internal(
+ "WhereOp: Could not launch cub::DeviceReduce::Sum to count "
+ "number of true indices, status: ",
+ cudaGetErrorString(second_success));
+ }
+
+ return Status::OK();
+ }
+};
+
+template struct NumTrue<GPUDevice, int32>;
+template struct NumTrue<GPUDevice, int64>;
+
+template <int NDIM>
+class WhereOutputIterator {
+ public:
+ // Required iterator traits
+ typedef WhereOutputIterator self_type;
+ typedef std::ptrdiff_t difference_type;
+ typedef void value_type;
+ typedef void pointer;
+ typedef int64& reference;
+
+#if (THRUST_VERSION >= 100700)
+ // Use Thrust's iterator categories so we can use these iterators in Thrust
+ // 1.7 (or newer) methods
+ typedef typename thrust::detail::iterator_facade_category<
+ thrust::device_system_tag, thrust::random_access_traversal_tag,
+ value_type,
+ reference>::type iterator_category; ///< The iterator category
+#else
+ typedef std::random_access_iterator_tag
+ iterator_category; ///< The iterator category
+#endif // THRUST_VERSION
+
+ WhereOutputIterator(int64* ptr, const Eigen::DenseIndex max_row)
+ : ptr_(ptr), max_row_(max_row) {}
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int64& operator[](int n) const {
+ // If the selection mechanism finds too many true values (because
+ // the input tensor changed between allocation of output and now),
+ // we may accidentally try to write past the allowable memory. If
+ // valid is false, the write is redirected to element 0; we then read off
+ // the number of items found in Flagged()'s d_num_selected_out at
+ // the end and confirm that it matches the number of rows of output.
+ const bool valid = FastBoundsCheck(n, max_row_);
+ return *(ptr_ + (valid ? (NDIM * n) : 0));
+ }
+
+ private:
+ int64* ptr_;
+ const Eigen::DenseIndex max_row_;
+};
+
+template <typename TIndex, int NDIM>
+Eigen::array<TIndex, NDIM> CalculateStrides(
+ typename TTypes<bool, NDIM>::ConstTensor input) {
+ const Eigen::DSizes<Eigen::DenseIndex, NDIM> dims = input.dimensions();
+ Eigen::array<TIndex, NDIM> strides;
+ EIGEN_STATIC_ASSERT((static_cast<int>(decltype(input)::Layout) ==
+ static_cast<int>(Eigen::RowMajor)),
+ INTERNAL_ERROR_INPUT_SHOULD_BE_ROWMAJOR);
+ strides[NDIM - 1] = 1;
+ for (int i = NDIM - 2; i >= 0; --i) {
+ strides[i] = strides[i + 1] * dims[i + 1];
+ }
+ return strides;
+}
+
+template <int NDIM, typename Tindex>
+struct Where<GPUDevice, NDIM, Tindex> {
+ EIGEN_ALWAYS_INLINE static Status Compute(
+ OpKernelContext* ctx, const GPUDevice& d,
+ typename TTypes<bool, NDIM>::ConstTensor input,
+ typename TTypes<int64>::Matrix output, Tindex* found_true_host) {
+ if (output.dimension(0) == 0) {
+ // Nothing to do.
+ return Status::OK();
+ }
+
+ std::size_t temp_storage_bytes = 0;
+
+ cub::CountingInputIterator<Tindex> select_counter(0);
+
+ Tensor found_true_t;
+ TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<Tindex>::v(),
+ TensorShape({}), &found_true_t));
+ Tindex* found_true_device = found_true_t.scalar<Tindex>().data();
+
+ WhereOutputIterator<NDIM> output_iterator(
+ output.data(),
+ /* max_row */ output.dimension(0));
+
+ auto first_success =
+ cub::DeviceSelect::Flagged(/*temp_storage*/ nullptr, temp_storage_bytes,
+ /*d_in*/ select_counter,
+ /*d_flags*/ input.data(),
+ /*d_out*/ output_iterator,
+ /*d_num_selected_out*/ found_true_device,
+ /*num_items*/ input.size(),
+ /*stream*/ d.stream());
+ if (first_success != cudaSuccess) {
+ return errors::Internal(
+ "WhereOp: Could not launch cub::DeviceSelect::Flagged to calculate "
+ "temp_storage_bytes, status: ",
+ cudaGetErrorString(first_success));
+ }
+
+ Tensor temp_storage;
+ TF_RETURN_IF_ERROR(ctx->allocate_temp(
+ DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
+ &temp_storage));
+
+ auto second_success = cub::DeviceSelect::Flagged(
+ /*temp_storage*/ temp_storage.flat<int8>().data(), temp_storage_bytes,
+ /*d_in*/ select_counter,
+ /*d_flags*/ input.data(),
+ /*d_out*/ output_iterator,
+ /*d_num_selected_out*/ found_true_device,
+ /*num_items*/ input.size(),
+ /*stream*/ d.stream());
+
+ if (second_success != cudaSuccess) {
+ return errors::Internal(
+ "WhereOp: Could not launch cub::DeviceSelect::Flagged to copy "
+ "indices out, status: ",
+ cudaGetErrorString(second_success));
+ }
+
+ // TODO(ebrevdo): Find a way to synchronously copy back data from
+ // found_true_device to *found_true_host.
+
+ const Eigen::array<Tindex, NDIM> strides =
+ CalculateStrides<Tindex, NDIM>(input);
+ const Tindex output_rows = output.dimension(0);
+ CudaLaunchConfig config = GetCudaLaunchConfig(output_rows, d);
+ PropagateWhereIndicesKernel<NDIM, Tindex>
+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ output_rows, strides, output.data());
+
+ return Status::OK();
+ }
+};
+
+#define DECLARE_GPU_SPEC_INDEX(Dims, Tindex) \
+ template struct Where<GPUDevice, Dims, Tindex>
+#define DECLARE_GPU_SPEC(Dims) \
+ DECLARE_GPU_SPEC_INDEX(Dims, int32); \
+ DECLARE_GPU_SPEC_INDEX(Dims, int64)
+
+DECLARE_GPU_SPEC(1);
+DECLARE_GPU_SPEC(2);
+DECLARE_GPU_SPEC(3);
+DECLARE_GPU_SPEC(4);
+DECLARE_GPU_SPEC(5);
+
+#undef DECLARE_GPU_SPEC
+#undef DECLARE_GPU_SPEC_INDEX
+
+} // namespace functor
+
+} // namespace tensorflow
+#endif // GOOGLE_CUDA
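Two details of the GPU path above are worth calling out. First, both NumTrue
and Where use CUB's standard two-phase call: the first invocation with a null
temp-storage pointer only computes temp_storage_bytes, the buffer is allocated,
and the call is repeated to do the real work. Second, pairing a
CountingInputIterator with DeviceSelect::Flagged compacts the flat indices of
the true flags; WhereOutputIterator then strides those writes by NDIM so each
flat index lands in column 0 of its output row, which
PropagateWhereIndicesKernel expands in place. A hedged, self-contained CUDA
sketch of the selection step (assumed sizes, plain device pointers instead of
the TensorFlow allocator):

    #include <cstdio>
    #include <cub/device/device_select.cuh>
    #include <cub/iterator/counting_input_iterator.cuh>
    #include <cuda_runtime.h>

    int main() {
      const int n = 8;
      bool h_flags[n] = {false, true, true, false, false, true, false, true};
      bool* d_flags; int* d_out; int* d_num_selected;
      cudaMalloc(&d_flags, n * sizeof(bool));
      cudaMalloc(&d_out, n * sizeof(int));
      cudaMalloc(&d_num_selected, sizeof(int));
      cudaMemcpy(d_flags, h_flags, n * sizeof(bool), cudaMemcpyHostToDevice);

      cub::CountingInputIterator<int> counting(0);  // yields 0, 1, 2, ...

      // Phase 1: size query -- null temp storage only fills temp_bytes.
      size_t temp_bytes = 0;
      cub::DeviceSelect::Flagged(nullptr, temp_bytes, counting, d_flags,
                                 d_out, d_num_selected, n);
      void* d_temp;
      cudaMalloc(&d_temp, temp_bytes);
      // Phase 2: the actual selection of flagged (true) positions.
      cub::DeviceSelect::Flagged(d_temp, temp_bytes, counting, d_flags,
                                 d_out, d_num_selected, n);

      int h_num = 0, h_out[n];
      cudaMemcpy(&h_num, d_num_selected, sizeof(int), cudaMemcpyDeviceToHost);
      cudaMemcpy(h_out, d_out, n * sizeof(int), cudaMemcpyDeviceToHost);
      printf("found %d true indices:", h_num);  // found 4 true indices: 1 2 5 7
      for (int i = 0; i < h_num; ++i) printf(" %d", h_out[i]);
      printf("\n");
      cudaFree(d_flags); cudaFree(d_out);
      cudaFree(d_num_selected); cudaFree(d_temp);
    }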
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index d8e069105f..64d0b8fa52 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -900,7 +900,7 @@ tf_py_test(
],
)
-tf_py_test(
+cuda_py_test(
name = "where_op_test",
size = "small",
srcs = ["where_op_test.py"],
diff --git a/tensorflow/python/kernel_tests/where_op_test.py b/tensorflow/python/kernel_tests/where_op_test.py
index b47159ae7f..a428d26996 100644
--- a/tensorflow/python/kernel_tests/where_op_test.py
+++ b/tensorflow/python/kernel_tests/where_op_test.py
@@ -18,17 +18,25 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import itertools
+import sys
+
import numpy as np
+from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
class WhereOpTest(test.TestCase):
def _testWhere(self, x, truth, expected_err_re=None):
- with self.test_session():
+ with self.test_session(use_gpu=True):
ans = array_ops.where(x)
self.assertEqual([None, x.ndim], ans.get_shape().as_list())
if expected_err_re is None:
@@ -39,12 +47,30 @@ class WhereOpTest(test.TestCase):
ans.eval()
def testWrongNumbers(self):
- with self.test_session():
+ with self.test_session(use_gpu=True):
with self.assertRaises(ValueError):
array_ops.where([False, True], [1, 2], None)
with self.assertRaises(ValueError):
array_ops.where([False, True], None, [1, 2])
+ def testBasicVec(self):
+ x = np.asarray([True, False])
+ truth = np.asarray([[0]], dtype=np.int64)
+ self._testWhere(x, truth)
+
+ x = np.asarray([False, True, False])
+ truth = np.asarray([[1]], dtype=np.int64)
+ self._testWhere(x, truth)
+
+ x = np.asarray([False, False, True, False, True])
+ truth = np.asarray([[2], [4]], dtype=np.int64)
+ self._testWhere(x, truth)
+
+ def testRandomVec(self):
+ x = np.random.rand(1000000) > 0.5
+ truth = np.vstack([np.where(x)[0].astype(np.int64)]).T
+ self._testWhere(x, truth)
+
def testBasicMat(self):
x = np.asarray([[True, False], [True, False]])
@@ -67,10 +93,37 @@ class WhereOpTest(test.TestCase):
def testThreeArgument(self):
x = np.array([[-2, 3, -1], [1, -3, -3]])
np_val = np.where(x > 0, x * x, -x)
- with self.test_session():
+ with self.test_session(use_gpu=True):
tf_val = array_ops.where(constant_op.constant(x) > 0, x * x, -x).eval()
self.assertAllEqual(tf_val, np_val)
+class WhereBenchmark(test.Benchmark):
+
+ def benchmarkWhere(self):
+ for (m, n, p, use_gpu) in itertools.product(
+ [10],
+ [10, 100, 1000, 10000, 100000, 1000000],
+ [0.01, 0.5, 0.99],
+ [False, True]):
+ name = "m_%d_n_%d_p_%g_use_gpu_%s" % (m, n, p, use_gpu)
+ device = "/%s:0" % ("gpu" if use_gpu else "cpu")
+ with ops.Graph().as_default():
+ with ops.device(device):
+ x = random_ops.random_uniform((m, n), dtype=dtypes.float32) <= p
+ v = resource_variable_ops.ResourceVariable(x)
+ op = array_ops.where(v)
+ with session.Session() as sess:
+ v.initializer.run()
+ r = self.run_op_benchmark(sess, op, min_iters=100, name=name)
+ gb_processed_input = m * n / 1.0e9
+ # Approximate output size: m*n*p true elements, two int64 coordinates each.
+ gb_processed_output = 2 * 8 * m * n * p / 1.0e9
+ gb_processed = gb_processed_input + gb_processed_output
+ throughput = gb_processed / r["wall_time"]
+ print("Benchmark: %s \t wall_time: %0.03g s \t "
+ "Throughput: %0.03g GB/s" % (name, r["wall_time"], throughput))
+ sys.stdout.flush()
+
if __name__ == "__main__":
test.main()
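For scale: at m = 10, n = 10^6, p = 0.5, the benchmark reads about 10^7
one-byte bools (0.01 GB) and writes roughly 2 * 8 * 10^7 * 0.5 = 8*10^7 bytes
of int64 indices (0.08 GB), so a 1 ms wall time reports about 90 GB/s.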
diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD
index 51ba3b7a0b..9da5d5cb5b 100644
--- a/tensorflow/tools/lib_package/BUILD
+++ b/tensorflow/tools/lib_package/BUILD
@@ -87,6 +87,7 @@ genrule(
"//third_party/fft2d:LICENSE",
"@boringssl//:LICENSE",
"@com_googlesource_code_re2//:LICENSE",
+ "@cub_archive//:LICENSE.TXT",
"@curl//:COPYING",
"@eigen_archive//:COPYING.MPL2",
"@farmhash_archive//:COPYING",
@@ -117,6 +118,7 @@ genrule(
"//third_party/fft2d:LICENSE",
"@boringssl//:LICENSE",
"@com_googlesource_code_re2//:LICENSE",
+ "@cub_archive//:LICENSE.TXT",
"@curl//:COPYING",
"@eigen_archive//:COPYING.MPL2",
"@farmhash_archive//:COPYING",
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 6a6ce5b9e3..0bd7cda1c5 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -95,6 +95,7 @@ filegroup(
"//third_party/hadoop:LICENSE.txt",
"@boringssl//:LICENSE",
"@com_googlesource_code_re2//:LICENSE",
+ "@cub_archive//:LICENSE.TXT",
"@curl//:COPYING",
"@eigen_archive//:COPYING.MPL2",
"@farmhash_archive//:COPYING",
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index a5172187f7..a6c5117e22 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -643,3 +643,19 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
strip_prefix = "pprof-c0fb62ec88c411cc91194465e54db2632845b650",
build_file = str(Label("//third_party:pprof.BUILD")),
)
+
+ native.new_http_archive(
+ name = "cub_archive",
+ urls = [
+ "http://mirror.bazel.build/github.com/NVlabs/cub/archive/1.6.4.zip",
+ "https://github.com/NVlabs/cub/archive/1.6.4.zip",
+ ],
+ sha256 = "966d0c4f41e2bdc81aebf9ccfbf0baffaac5a74f00b826b06f4dee79b2bb8cee",
+ strip_prefix = "cub-1.6.4",
+ build_file = str(Label("//third_party:cub.BUILD")),
+ )
+
+ native.bind(
+ name = "cub",
+ actual = "@cub_archive//:cub",
+ )
diff --git a/third_party/cub.BUILD b/third_party/cub.BUILD
new file mode 100644
index 0000000000..29159c9dad
--- /dev/null
+++ b/third_party/cub.BUILD
@@ -0,0 +1,26 @@
+# Description: CUB library which is a set of primitives for GPU programming.
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # BSD
+
+exports_files(["LICENSE.TXT"])
+
+load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "if_cuda")
+
+filegroup(
+ name = "cub_header_files",
+ srcs = glob([
+ "cub/**",
+ ]),
+)
+
+cc_library(
+ name = "cub",
+ hdrs = if_cuda([":cub_header_files"]),
+ deps = [
+ "@local_config_cuda//cuda:cuda_headers",
+ ],
+)