diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2017-06-29 15:33:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-29 15:37:15 -0700 |
commit | 8280e0ae9083a65b23608b34723f07e028a56dc8 (patch) | |
tree | 0f2df282cfd5cd712920e440cea88a093668cbf2 | |
parent | 4aa7c4d2330ce110b5be348144ee67143841272c (diff) |
GPU-enabled WhereOp using CUB.
* Import CUB.
* Add GPU-enabled async WhereOp.
* Added benchmarks.
* Added support for bool ResourceVariables on GPU.
Benchmark results on machine with single K40 tesla GPU:
Where on bool matrix shape [m x n] with p percentage values true below.
For small-medium sizes, running WhereOp on GPU is ~4-2x slower. For
realistic large problem sizes, it's 2-5x faster. This timing ignores
the time spent copying a tensor from GPU -> CPU and back from CPU -> GPU
when the WhereOp is between GPU computations (so the performance impact
should actually be better).
Benchmark: m_10_n_10_p_0.01_use_gpu_False wall_time: 9.01e-05 s Throughput: 0.00129 GB/s
Benchmark: m_10_n_10_p_0.01_use_gpu_True wall_time: 0.000187 s Throughput: 0.000621 GB/s
Benchmark: m_10_n_10_p_0.5_use_gpu_False wall_time: 9.3e-05 s Throughput: 0.00968 GB/s
Benchmark: m_10_n_10_p_0.5_use_gpu_True wall_time: 0.000252 s Throughput: 0.00357 GB/s
Benchmark: m_10_n_10_p_0.99_use_gpu_False wall_time: 0.000152 s Throughput: 0.0111 GB/s
Benchmark: m_10_n_10_p_0.99_use_gpu_True wall_time: 0.000245 s Throughput: 0.00687 GB/s
Benchmark: m_10_n_100_p_0.01_use_gpu_False wall_time: 9.3e-05 s Throughput: 0.0125 GB/s
Benchmark: m_10_n_100_p_0.01_use_gpu_True wall_time: 0.000253 s Throughput: 0.00458 GB/s
Benchmark: m_10_n_100_p_0.5_use_gpu_False wall_time: 9.8e-05 s Throughput: 0.0918 GB/s
Benchmark: m_10_n_100_p_0.5_use_gpu_True wall_time: 0.00026 s Throughput: 0.0346 GB/s
Benchmark: m_10_n_100_p_0.99_use_gpu_False wall_time: 0.000104 s Throughput: 0.162 GB/s
Benchmark: m_10_n_100_p_0.99_use_gpu_True wall_time: 0.000288 s Throughput: 0.0586 GB/s
Benchmark: m_10_n_1000_p_0.01_use_gpu_False wall_time: 0.000105 s Throughput: 0.111 GB/s
Benchmark: m_10_n_1000_p_0.01_use_gpu_True wall_time: 0.000283 s Throughput: 0.041 GB/s
Benchmark: m_10_n_1000_p_0.5_use_gpu_False wall_time: 0.000185 s Throughput: 0.486 GB/s
Benchmark: m_10_n_1000_p_0.5_use_gpu_True wall_time: 0.000335 s Throughput: 0.269 GB/s
Benchmark: m_10_n_1000_p_0.99_use_gpu_False wall_time: 0.000203 s Throughput: 0.83 GB/s
Benchmark: m_10_n_1000_p_0.99_use_gpu_True wall_time: 0.000346 s Throughput: 0.486 GB/s
Benchmark: m_10_n_10000_p_0.01_use_gpu_False wall_time: 0.00019 s Throughput: 0.609 GB/s
Benchmark: m_10_n_10000_p_0.01_use_gpu_True wall_time: 0.00028 s Throughput: 0.414 GB/s
Benchmark: m_10_n_10000_p_0.5_use_gpu_False wall_time: 0.00117 s Throughput: 0.771 GB/s
Benchmark: m_10_n_10000_p_0.5_use_gpu_True wall_time: 0.000426 s Throughput: 2.11 GB/s
Benchmark: m_10_n_10000_p_0.99_use_gpu_False wall_time: 0.0014 s Throughput: 1.2 GB/s
Benchmark: m_10_n_10000_p_0.99_use_gpu_True wall_time: 0.000482 s Throughput: 3.5 GB/s
Benchmark: m_10_n_100000_p_0.01_use_gpu_False wall_time: 0.00129 s Throughput: 0.899 GB/s
Benchmark: m_10_n_100000_p_0.01_use_gpu_True wall_time: 0.000336 s Throughput: 3.45 GB/s
Benchmark: m_10_n_100000_p_0.5_use_gpu_False wall_time: 0.0102 s Throughput: 0.885 GB/s
Benchmark: m_10_n_100000_p_0.5_use_gpu_True wall_time: 0.00136 s Throughput: 6.6 GB/s
Benchmark: m_10_n_100000_p_0.99_use_gpu_False wall_time: 0.0116 s Throughput: 1.45 GB/s
Benchmark: m_10_n_100000_p_0.99_use_gpu_True wall_time: 0.00233 s Throughput: 7.23 GB/s
Benchmark: m_10_n_1000000_p_0.01_use_gpu_False wall_time: 0.0111 s Throughput: 1.04 GB/s
Benchmark: m_10_n_1000000_p_0.01_use_gpu_True wall_time: 0.00109 s Throughput: 10.6 GB/s
Benchmark: m_10_n_1000000_p_0.5_use_gpu_False wall_time: 0.0895 s Throughput: 1.01 GB/s
Benchmark: m_10_n_1000000_p_0.5_use_gpu_True wall_time: 0.0103 s Throughput: 8.7 GB/s
Benchmark: m_10_n_1000000_p_0.99_use_gpu_False wall_time: 0.107 s Throughput: 1.58 GB/s
Benchmark: m_10_n_1000000_p_0.99_use_gpu_True wall_time: 0.0201 s Throughput: 8.39 GB/s
PiperOrigin-RevId: 160582709
-rw-r--r-- | tensorflow/contrib/cmake/CMakeLists.txt | 3 | ||||
-rw-r--r-- | tensorflow/contrib/cmake/external/cub.cmake | 29 | ||||
-rw-r--r-- | tensorflow/contrib/cmake/patches/cub/CMakeLists.txt | 3 | ||||
-rw-r--r-- | tensorflow/core/framework/register_types.h | 5 | ||||
-rw-r--r-- | tensorflow/core/kernels/BUILD | 5 | ||||
-rw-r--r-- | tensorflow/core/kernels/dense_update_ops.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/kernels/dense_update_ops_gpu.cu.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/kernels/resource_variable_ops.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/kernels/where_op.cc | 259 | ||||
-rw-r--r-- | tensorflow/core/kernels/where_op.h | 61 | ||||
-rw-r--r-- | tensorflow/core/kernels/where_op_gpu.cu.cc | 253 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/where_op_test.py | 59 | ||||
-rw-r--r-- | tensorflow/tools/lib_package/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/tools/pip_package/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/workspace.bzl | 16 | ||||
-rw-r--r-- | third_party/cub.BUILD | 26 |
17 files changed, 665 insertions, 73 deletions
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index fe66f8c49a..655e0b29a0 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -122,6 +122,7 @@ include(fft2d) include(highwayhash) include(protobuf) include(re2) +include(cub) if (tensorflow_BUILD_CC_TESTS) include(googletest) endif() @@ -151,6 +152,7 @@ set(tensorflow_EXTERNAL_DEPENDENCIES protobuf eigen gemmlowp + cub fft2d re2 ) @@ -170,6 +172,7 @@ include_directories( ${jsoncpp_INCLUDE_DIR} ${farmhash_INCLUDE_DIR} ${highwayhash_INCLUDE_DIR} + ${cub_INCLUDE_DIR} ${PROTOBUF_INCLUDE_DIRS} ${re2_INCLUDE_DIR} ) diff --git a/tensorflow/contrib/cmake/external/cub.cmake b/tensorflow/contrib/cmake/external/cub.cmake new file mode 100644 index 0000000000..54329dbd70 --- /dev/null +++ b/tensorflow/contrib/cmake/external/cub.cmake @@ -0,0 +1,29 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +include (ExternalProject) + +set(cub_URL https://github.com/NVlabs/cub/archive/1.6.4.zip) +set(cub_HASH SHA256=966d0c4f41e2bdc81aebf9ccfbf0baffaac5a74f00b826b06f4dee79b2bb8cee) +set(cub_BUILD ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub) +set(cub_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub) + +ExternalProject_Add(cub + PREFIX cub + URL ${cub_URL} + URL_HASH ${cub_HASH} + DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" + BUILD_IN_SOURCE 1 + PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/cub/CMakeLists.txt ${cub_BUILD} + INSTALL_COMMAND "") diff --git a/tensorflow/contrib/cmake/patches/cub/CMakeLists.txt b/tensorflow/contrib/cmake/patches/cub/CMakeLists.txt new file mode 100644 index 0000000000..36890f0ce6 --- /dev/null +++ b/tensorflow/contrib/cmake/patches/cub/CMakeLists.txt @@ -0,0 +1,3 @@ +cmake_minimum_required(VERSION 2.8.3) + +project(cub) diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h index 2f7b140295..b62fe647e2 100644 --- a/tensorflow/core/framework/register_types.h +++ b/tensorflow/core/framework/register_types.h @@ -167,10 +167,13 @@ limitations under the License. // Call "m" on POD and string types. #define TF_CALL_POD_STRING_TYPES(m) TF_CALL_POD_TYPES(m) TF_CALL_string(m) -// Call "m" on all types supported on GPU. +// Call "m" on all number types supported on GPU. #define TF_CALL_GPU_NUMBER_TYPES(m) \ TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m) +// Call "m" on all types supported on GPU. +#define TF_CALL_GPU_ALL_TYPES(m) TF_CALL_GPU_NUMBER_TYPES(m) TF_CALL_bool(m) + #define TF_CALL_GPU_NUMBER_TYPES_NO_HALF(m) TF_CALL_float(m) TF_CALL_double(m) // Call "m" on all quantized types. diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index f0dc3312c2..568e9b717d 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -763,7 +763,10 @@ tf_kernel_library( tf_kernel_library( name = "where_op", prefix = "where_op", - deps = ARRAY_DEPS, + deps = if_cuda([ + ":cuda_solvers", + "@cub_archive//:cub", + ]) + ARRAY_DEPS, ) tf_cc_test( diff --git a/tensorflow/core/kernels/dense_update_ops.cc b/tensorflow/core/kernels/dense_update_ops.cc index ef34946d96..af16c55a5d 100644 --- a/tensorflow/core/kernels/dense_update_ops.cc +++ b/tensorflow/core/kernels/dense_update_ops.cc @@ -154,7 +154,7 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); Name("Assign").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ AssignOpT<GPUDevice, type>); -TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); +TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/dense_update_ops_gpu.cu.cc b/tensorflow/core/kernels/dense_update_ops_gpu.cu.cc index 0f61506afe..cf2d493061 100644 --- a/tensorflow/core/kernels/dense_update_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/dense_update_ops_gpu.cu.cc @@ -27,11 +27,15 @@ typedef Eigen::GpuDevice GPUDevice; #define DEFINE_GPU_KERNELS(T) \ template struct functor::DenseUpdate<GPUDevice, T, ADD>; \ - template struct functor::DenseUpdate<GPUDevice, T, SUB>; \ - template struct functor::DenseUpdate<GPUDevice, T, ASSIGN>; + template struct functor::DenseUpdate<GPUDevice, T, SUB>; TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); #undef DEFINE_GPU_KERNELS +#define DEFINE_GPU_KERNELS(T) \ + template struct functor::DenseUpdate<GPUDevice, T, ASSIGN>; +TF_CALL_GPU_ALL_TYPES(DEFINE_GPU_KERNELS); +#undef DEFINE_GPU_KERNELS + } // end namespace tensorflow #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index db6d8c20ec..e108a59275 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -94,7 +94,7 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); .HostMemory("resource"), \ ReadVariableOp<GPUDevice, type>); -TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); +TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS #endif // GOOGLE_CUDA @@ -230,7 +230,7 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); .HostMemory("resource"), \ AssignVariableOp<GPUDevice, type>); -TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); +TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/where_op.cc b/tensorflow/core/kernels/where_op.cc index e56a498845..6fdcb331cc 100644 --- a/tensorflow/core/kernels/where_op.cc +++ b/tensorflow/core/kernels/where_op.cc @@ -17,9 +17,14 @@ limitations under the License. #define EIGEN_USE_THREADS +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + #include "tensorflow/core/kernels/where_op.h" #include <memory> +#include <numeric> #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -27,43 +32,113 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" +#if GOOGLE_CUDA +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include "tensorflow/core/kernels/cuda_solvers.h" +#endif // GOOGLE_CUDA + namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -template <typename Device> -class WhereOp : public OpKernel { +namespace functor { + +template <> +struct NumTrue<CPUDevice, int64> { + static Status Compute(OpKernelContext* ctx, const CPUDevice& d, + TTypes<bool>::ConstFlat input, + TTypes<int64>::Scalar num_true) { + *num_true.data() = + std::accumulate(input.data(), input.data() + input.size(), 0); + return Status::OK(); + } +}; + +template <int DIMS, typename TIndex> +struct Where<CPUDevice, DIMS, TIndex> { + EIGEN_ALWAYS_INLINE static void WriteIndexRowMajor( + typename TTypes<int64>::Matrix output, + const typename Eigen::DSizes<TIndex, DIMS>& strides, TIndex true_n, + TIndex index) { + for (int i = 0; i < DIMS; ++i) { + output(true_n, i) = index / strides[i]; + index -= output(true_n, i) * strides[i]; + } + } + + EIGEN_ALWAYS_INLINE static Status Compute( + OpKernelContext* ctx, const CPUDevice& d, + typename TTypes<bool, DIMS>::ConstTensor input, + typename TTypes<int64>::Matrix output, TIndex* found_true) { + Eigen::DSizes<Eigen::DenseIndex, DIMS> dims = input.dimensions(); + Eigen::DSizes<TIndex, DIMS> strides; + + EIGEN_STATIC_ASSERT((static_cast<int>(decltype(input)::Layout) == + static_cast<int>(Eigen::RowMajor)), + INTERNAL_ERROR_INPUT_SHOULD_BE_ROWMAJOR); + + strides[DIMS - 1] = 1; + for (int i = DIMS - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * dims[i + 1]; + } + + Eigen::DenseIndex output_size = output.dimension(0); + for (Eigen::DenseIndex n = 0; n < input.size(); ++n) { + if (input.data()[n]) { + if (FastBoundsCheck(*found_true, output_size)) { + WriteIndexRowMajor(output, strides, *found_true, n); + } + ++*found_true; + } + } + return Status::OK(); + } +}; + +} // namespace functor + +class WhereCPUOp : public OpKernel { public: - explicit WhereOp(OpKernelConstruction* context) : OpKernel(context) {} + explicit WhereCPUOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); const int input_dims = input.dims(); + Tensor num_true; OP_REQUIRES_OK( context, context->allocate_temp(DT_INT64, TensorShape({}), &num_true)); auto num_true_t = num_true.scalar<int64>(); - functor::NumTrue<Device>::Compute(context->eigen_device<Device>(), - input.flat<bool>(), num_true_t); + Status s = functor::NumTrue<CPUDevice, int64>::Compute( + context, context->eigen_device<CPUDevice>(), input.flat<bool>(), + num_true_t); + OP_REQUIRES_OK(context, s); TensorShape output_shape({num_true_t(), input_dims}); Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); -#define HANDLE_DIM(NDIM) \ - case NDIM: \ - found_true = functor::Where<Device, NDIM>::Compute( \ - context->eigen_device<Device>(), input.tensor<bool, NDIM>(), \ - output->matrix<int64>()); \ - break; - + // TODO(ebrevdo): Replace single-threaded copy with a + // multithreaded block copy by getting block counts above instead + // of a global NumTrue, then having each block filled in in + // separate threads below. int64 found_true = 0; + +#define HANDLE_DIM(NDIM) \ + case NDIM: { \ + Status s = functor::Where<CPUDevice, NDIM, int64>::Compute( \ + context, context->eigen_device<CPUDevice>(), \ + input.tensor<bool, NDIM>(), output->matrix<int64>(), &found_true); \ + OP_REQUIRES_OK(context, s); \ + } break; + switch (input_dims) { HANDLE_DIM(1); HANDLE_DIM(2); @@ -79,7 +154,7 @@ class WhereOp : public OpKernel { #undef HANDLE_DIM OP_REQUIRES( - context, num_true_t() == found_true, + context, found_true == num_true_t(), errors::InvalidArgument( "WhereOp: Race condition between counting the number of true " "elements and writing them. When counting, saw ", @@ -88,12 +163,162 @@ class WhereOp : public OpKernel { } private: - TF_DISALLOW_COPY_AND_ASSIGN(WhereOp); + TF_DISALLOW_COPY_AND_ASSIGN(WhereCPUOp); +}; + +REGISTER_KERNEL_BUILDER(Name("Where").Device(DEVICE_CPU), WhereCPUOp); + +#if GOOGLE_CUDA + +namespace functor { + +#define DECLARE_GPU_NUMTRUE(Tindex) \ + template <> \ + Status NumTrue<GPUDevice, Tindex>::Compute( \ + OpKernelContext* ctx, const GPUDevice& d, TTypes<bool>::ConstFlat input, \ + TTypes<Tindex>::Scalar num_true); \ + extern template struct NumTrue<GPUDevice, Tindex> + +DECLARE_GPU_NUMTRUE(int32); +DECLARE_GPU_NUMTRUE(int64); +#undef DECLARE_GPU_NUMTRUE + +#define DECLARE_GPU_WHERE_INDEX(Dims, Tindex) \ + template <> \ + Status Where<GPUDevice, Dims, Tindex>::Compute( \ + OpKernelContext* ctx, const GPUDevice& d, \ + typename TTypes<bool, Dims>::ConstTensor input, \ + typename TTypes<int64>::Matrix output, Tindex* found_true); \ + extern template struct Where<GPUDevice, Dims, Tindex>; +#define DECLARE_GPU_WHERE(Dims) \ + DECLARE_GPU_WHERE_INDEX(Dims, int32); \ + DECLARE_GPU_WHERE_INDEX(Dims, int64); + +DECLARE_GPU_WHERE(1); +DECLARE_GPU_WHERE(2); +DECLARE_GPU_WHERE(3); +DECLARE_GPU_WHERE(4); +DECLARE_GPU_WHERE(5); +#undef DECLARE_GPU_WHERE +#undef DECLARE_GPU_WHERE_INDEX + +} // namespace functor + +class WhereGPUOp : public AsyncOpKernel { + public: + explicit WhereGPUOp(OpKernelConstruction* context) : AsyncOpKernel(context) {} + + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + const Tensor& input = context->input(0); + const int input_dims = input.dims(); + + if (input.NumElements() < std::numeric_limits<int32>::max()) { + ComputeAsyncType<int32>(input, input_dims, context, done); + } else { + ComputeAsyncType<int64>(input, input_dims, context, done); + } + } + + template <typename Tindex> + void ComputeAsyncType(const Tensor& input, const int input_dims, + OpKernelContext* context, DoneCallback done) { + // Step 0: alloc nnz + // Step 1: call nnz kernel + // Step 2: copy nnz to host + // Step 3: call create_output + // Step 4: call where kernel + Tensor num_true; + OP_REQUIRES_OK_ASYNC(context, + context->allocate_temp(DataTypeToEnum<Tindex>::v(), + TensorShape({}), &num_true), + done); + + auto num_true_t = num_true.scalar<Tindex>(); + + perftools::gputools::DeviceMemoryBase num_true_ptr( + static_cast<void*>(num_true_t.data())); + // Push kernel to stream to get number of true elements. + const GPUDevice& d = context->eigen_device<GPUDevice>(); + Status s = functor::NumTrue<GPUDevice, Tindex>::Compute( + context, d, input.flat<bool>(), num_true_t); + OP_REQUIRES_OK_ASYNC(context, s, done); + + // Copy num_true to host; + ScratchSpace<Tindex> num_true_host(context, 1, /* on_host */ true); + + auto stream = context->op_device_context()->stream(); + OP_REQUIRES_ASYNC( + context, + stream + ->ThenMemcpy(num_true_host.mutable_data(), num_true_ptr, + sizeof(Tindex)) + .ok(), + errors::Internal("WhereOp: failed to copy num_true from device"), done); + + auto create_and_check_output = [context, &d, &input, input_dims, + num_true_host, done]() { + + Tindex num_true = *num_true_host.data(); + + // TODO(ebrevdo): Properly copy back found_true value to CPU for + // validation checking. Currently Where<GPUDevice>::Compute() + // does not perform this copy back to CPU. + Tindex found_true = -1; + + // Step 1: Allocate the output and perform the selection/copy. + Tensor* output; + OP_REQUIRES_OK_ASYNC(context, + context->allocate_output( + 0, TensorShape({num_true, input_dims}), &output), + done); + +#define HANDLE_DIM(NDIM) \ + case NDIM: { \ + Status s = functor::Where<GPUDevice, NDIM, Tindex>::Compute( \ + context, d, input.tensor<bool, NDIM>(), output->matrix<int64>(), \ + &found_true); \ + OP_REQUIRES_OK_ASYNC(context, s, done); \ + } break; + + switch (input_dims) { + HANDLE_DIM(1); + HANDLE_DIM(2); + HANDLE_DIM(3); + HANDLE_DIM(4); + HANDLE_DIM(5); + + default: + OP_REQUIRES_ASYNC( + context, false, + errors::InvalidArgument("WhereOp: Unhandled input dimensions: ", + input_dims), + done); + } +#undef HANDLE_DIM + + // TODO(ebrevdo): Fix the copy back to host. + + // OP_REQUIRES_ASYNC( + // context, found_true == num_true, + // errors::InvalidArgument( + // "WhereOp: Race condition between counting the number of true " + // "elements and writing them. When counting, saw ", + // num_true, " elements; but when writing their indices, saw ", + // found_true, " elements."), + // done); + + done(); + }; + context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute( + stream, create_and_check_output); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(WhereGPUOp); }; -#define REGISTER_WHERE() \ - REGISTER_KERNEL_BUILDER(Name("Where").Device(DEVICE_CPU), WhereOp<CPUDevice>); +REGISTER_KERNEL_BUILDER(Name("Where").Device(DEVICE_GPU), WhereGPUOp); -REGISTER_WHERE(); +#endif // GOOGLE_CUDA } // namespace tensorflow diff --git a/tensorflow/core/kernels/where_op.h b/tensorflow/core/kernels/where_op.h index aa27123714..e040325e3d 100644 --- a/tensorflow/core/kernels/where_op.h +++ b/tensorflow/core/kernels/where_op.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_KERNELS_WHERE_OP_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -25,55 +26,25 @@ namespace tensorflow { namespace functor { -template <typename Device> +template <typename Device, typename TIndex> struct NumTrue { - EIGEN_ALWAYS_INLINE static void Compute( - const Device& d, typename TTypes<bool>::ConstFlat input, - TTypes<int64>::Scalar num_true) { - num_true.device(d) = input.template cast<int64>().sum(); - } + EIGEN_ALWAYS_INLINE static Status Compute( + OpKernelContext* ctx, const Device& d, TTypes<bool>::ConstFlat input, + typename TTypes<TIndex>::Scalar num_true); }; -template <typename Device, int NDIM> +template <typename Device, int NDIM, typename TIndex> struct Where { - EIGEN_ALWAYS_INLINE static int64 Compute( - const Device& d, typename TTypes<bool, NDIM>::ConstTensor input, - typename TTypes<int64>::Matrix output) { - Eigen::DenseIndex true_n = 0; - Eigen::DSizes<Eigen::DenseIndex, NDIM> dims = input.dimensions(); - Eigen::DSizes<Eigen::DenseIndex, NDIM> strides; - - // Calculate strides for RowMajor order. - EIGEN_STATIC_ASSERT((static_cast<int>(decltype(input)::Layout) == - static_cast<int>(Eigen::RowMajor)), - INTERNAL_ERROR_INPUT_SHOULD_BE_ROWMAJOR); - - strides[NDIM - 1] = 1; - for (int i = NDIM - 2; i >= 0; --i) { - strides[i] = strides[i + 1] * dims[i + 1]; - } - - Eigen::DenseIndex output_size = output.dimension(0); - for (Eigen::DenseIndex n = 0; n < input.size(); ++n) { - if (input.data()[n]) { - if (TF_PREDICT_TRUE(true_n < output_size)) { - WriteIndexRowMajor(output, strides, true_n, n); - } - ++true_n; - } - } - return true_n; - } - - EIGEN_ALWAYS_INLINE static void WriteIndexRowMajor( - typename TTypes<int64>::Matrix output, - const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides, - Eigen::DenseIndex true_n, Eigen::DenseIndex index) { - for (int i = 0; i < NDIM; ++i) { - output(true_n, i) = index / strides[i]; - index %= strides[i]; - } - } + // Copies indices of true values in input into output. The pointer + // found_true should sit on the host. Compute should copy the + // number of true elements found into it. At the end, if + // *found_true != output.dimension(0), + // then the input may have changed between the initial counting of + // the true values and the call to Where. + EIGEN_ALWAYS_INLINE static Status Compute( + OpKernelContext* ctx, const Device& d, + typename TTypes<bool, NDIM>::ConstTensor input, + typename TTypes<int64>::Matrix output, TIndex* found_true); }; } // namespace functor diff --git a/tensorflow/core/kernels/where_op_gpu.cu.cc b/tensorflow/core/kernels/where_op_gpu.cu.cc new file mode 100644 index 0000000000..09e8be58fa --- /dev/null +++ b/tensorflow/core/kernels/where_op_gpu.cu.cc @@ -0,0 +1,253 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "external/cub_archive/cub/device/device_reduce.cuh" +#include "external/cub_archive/cub/device/device_select.cuh" +#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/where_op.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +template <int NDIM, typename TIndex> +__global__ void PropagateWhereIndicesKernel( + const TIndex output_rows, const typename Eigen::array<TIndex, NDIM> strides, + int64* output) { + // TODO(ebrevdo): Use a multi-dimensional loop, increasing the + // dimensions of individual indices manually, instead of relying on + // a scalar loop variable and using integer division. + CUDA_1D_KERNEL_LOOP(i, output_rows) { + TIndex index_value = ldg(output + NDIM * i); +#pragma unroll + for (int c = 0; c < NDIM; ++c) { + *(output + NDIM * i + c) = index_value / strides[c]; + index_value %= strides[c]; + } + } +} + +template <typename TIndex> +struct NumTrue<GPUDevice, TIndex> { + EIGEN_ALWAYS_INLINE static Status Compute( + OpKernelContext* ctx, const GPUDevice& d, TTypes<bool>::ConstFlat input, + typename TTypes<TIndex>::Scalar num_true) { + std::size_t temp_storage_bytes = 0; + + const bool* input_data = input.data(); + TIndex* num_true_data = num_true.data(); + + auto first_success = + cub::DeviceReduce::Sum(/*temp_storage*/ nullptr, temp_storage_bytes, + /*d_in*/ input_data, + /*d_out*/ num_true_data, + /*num_items*/ input.size(), + /*stream*/ d.stream()); + + if (first_success != cudaSuccess) { + return errors::Internal( + "WhereOp: Could not launch cub::DeviceReduce::Sum to calculate " + "temp_storage_bytes, status: ", + cudaGetErrorString(first_success)); + } + + Tensor temp_storage; + TF_RETURN_IF_ERROR(ctx->allocate_temp( + DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}), + &temp_storage)); + + auto second_success = cub::DeviceReduce::Sum( + /*temp_storage*/ temp_storage.flat<int8>().data(), temp_storage_bytes, + /*d_in*/ input_data, + /*d_out*/ num_true_data, + /*num_items*/ input.size(), + /*stream*/ d.stream()); + + if (second_success != cudaSuccess) { + return errors::Internal( + "WhereOp: Could not launch cub::DeviceReduce::Sum to count " + "number of true indices, status: ", + cudaGetErrorString(second_success)); + } + + return Status::OK(); + } +}; + +template struct NumTrue<GPUDevice, int32>; +template struct NumTrue<GPUDevice, int64>; + +template <int NDIM> +class WhereOutputIterator { + public: + // Required iterator traits + typedef WhereOutputIterator self_type; + typedef std::ptrdiff_t difference_type; + typedef void value_type; + typedef void pointer; + typedef int64& reference; + +#if (THRUST_VERSION >= 100700) + // Use Thrust's iterator categories so we can use these iterators in Thrust + // 1.7 (or newer) methods + typedef typename thrust::detail::iterator_facade_category< + thrust::device_system_tag, thrust::random_access_traversal_tag, + value_type, + reference>::type iterator_category; ///< The iterator category +#else + typedef std::random_access_iterator_tag + iterator_category; ///< The iterator category +#endif // THRUST_VERSION + + WhereOutputIterator(int64* ptr, const Eigen::DenseIndex max_row) + : ptr_(ptr), max_row_(max_row) {} + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int64& operator[](int n) const { + // If the selection mechanism finds too many true values (because + // the input tensor changed between allocation of output and now), + // we may accidentally try to write past the allowable memory. If + // valid is false, then we don't do this. Instead, we'll read off + // the number of items found in Flagged()'s d_num_selected_out at + // the end and confirm that it matches the number of rows of output. + const bool valid = FastBoundsCheck(n, max_row_); + return *(ptr_ + (valid ? (NDIM * n) : 0)); + } + + private: + int64* ptr_; + const Eigen::DenseIndex max_row_; +}; + +template <typename TIndex, int NDIM> +Eigen::array<TIndex, NDIM> CalculateStrides( + typename TTypes<bool, NDIM>::ConstTensor input) { + const Eigen::DSizes<Eigen::DenseIndex, NDIM> dims = input.dimensions(); + Eigen::array<TIndex, NDIM> strides; + EIGEN_STATIC_ASSERT((static_cast<int>(decltype(input)::Layout) == + static_cast<int>(Eigen::RowMajor)), + INTERNAL_ERROR_INPUT_SHOULD_BE_ROWMAJOR); + strides[NDIM - 1] = 1; + for (int i = NDIM - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * dims[i + 1]; + } + return strides; +} + +template <int NDIM, typename Tindex> +struct Where<GPUDevice, NDIM, Tindex> { + EIGEN_ALWAYS_INLINE static Status Compute( + OpKernelContext* ctx, const GPUDevice& d, + typename TTypes<bool, NDIM>::ConstTensor input, + typename TTypes<int64>::Matrix output, Tindex* found_true_host) { + if (output.dimension(0) == 0) { + // Nothing to do. + return Status::OK(); + } + + std::size_t temp_storage_bytes = 0; + + cub::CountingInputIterator<Tindex> select_counter(0); + + Tensor found_true_t; + TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<Tindex>::v(), + TensorShape({}), &found_true_t)); + Tindex* found_true_device = found_true_t.scalar<Tindex>().data(); + + WhereOutputIterator<NDIM> output_iterator( + output.data(), + /* max_row */ output.dimension(0)); + + auto first_success = + cub::DeviceSelect::Flagged(/*temp_storage*/ nullptr, temp_storage_bytes, + /*d_in*/ select_counter, + /*d_flags*/ input.data(), + /*d_out*/ output_iterator, + /*d_num_selected_out*/ found_true_device, + /*num_items*/ input.size(), + /*stream*/ d.stream()); + if (first_success != cudaSuccess) { + return errors::Internal( + "WhereOp: Could not launch cub::DeviceSelect::Flagged to calculate " + "temp_storage_bytes, status: ", + cudaGetErrorString(first_success)); + } + + Tensor temp_storage; + TF_RETURN_IF_ERROR(ctx->allocate_temp( + DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}), + &temp_storage)); + + auto second_success = cub::DeviceSelect::Flagged( + /*temp_storage*/ temp_storage.flat<int8>().data(), temp_storage_bytes, + /*d_in*/ select_counter, + /*d_flags*/ input.data(), + /*d_out*/ output_iterator, + /*d_num_selected_out*/ found_true_device, + /*num_items*/ input.size(), + /*stream*/ d.stream()); + + if (second_success != cudaSuccess) { + return errors::Internal( + "WhereOp: Could not launch cub::DeviceSelect::Flagged to copy " + "indices out, status: ", + cudaGetErrorString(second_success)); + } + + // TODO(ebrevdo): Find a way to synchronously copy back data from + // found_true_device to *found_true_host. + + const Eigen::array<Tindex, NDIM> strides = + CalculateStrides<Tindex, NDIM>(input); + const Tindex output_rows = output.dimension(0); + CudaLaunchConfig config = GetCudaLaunchConfig(output_rows, d); + PropagateWhereIndicesKernel<NDIM, Tindex> + <<<config.block_count, config.thread_per_block, 0, d.stream()>>>( + output_rows, strides, output.data()); + + return Status::OK(); + } +}; + +#define DECLARE_GPU_SPEC_INDEX(Dims, Tindex) \ + template struct Where<GPUDevice, Dims, Tindex> +#define DECLARE_GPU_SPEC(Dims) \ + DECLARE_GPU_SPEC_INDEX(Dims, int32); \ + DECLARE_GPU_SPEC_INDEX(Dims, int64) + +DECLARE_GPU_SPEC(1); +DECLARE_GPU_SPEC(2); +DECLARE_GPU_SPEC(3); +DECLARE_GPU_SPEC(4); +DECLARE_GPU_SPEC(5); + +#undef DECLARE_GPU_SPEC +#undef DECLARE_GPU_SPEC_INDEX + +} // namespace functor + +} // namespace tensorflow +#endif // GOOGLE_CUDA diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index d8e069105f..64d0b8fa52 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -900,7 +900,7 @@ tf_py_test( ], ) -tf_py_test( +cuda_py_test( name = "where_op_test", size = "small", srcs = ["where_op_test.py"], diff --git a/tensorflow/python/kernel_tests/where_op_test.py b/tensorflow/python/kernel_tests/where_op_test.py index b47159ae7f..a428d26996 100644 --- a/tensorflow/python/kernel_tests/where_op_test.py +++ b/tensorflow/python/kernel_tests/where_op_test.py @@ -18,17 +18,25 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import itertools +import sys + import numpy as np +from tensorflow.python.client import session from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test class WhereOpTest(test.TestCase): def _testWhere(self, x, truth, expected_err_re=None): - with self.test_session(): + with self.test_session(use_gpu=True): ans = array_ops.where(x) self.assertEqual([None, x.ndim], ans.get_shape().as_list()) if expected_err_re is None: @@ -39,12 +47,30 @@ class WhereOpTest(test.TestCase): ans.eval() def testWrongNumbers(self): - with self.test_session(): + with self.test_session(use_gpu=True): with self.assertRaises(ValueError): array_ops.where([False, True], [1, 2], None) with self.assertRaises(ValueError): array_ops.where([False, True], None, [1, 2]) + def testBasicVec(self): + x = np.asarray([True, False]) + truth = np.asarray([[0]], dtype=np.int64) + self._testWhere(x, truth) + + x = np.asarray([False, True, False]) + truth = np.asarray([[1]], dtype=np.int64) + self._testWhere(x, truth) + + x = np.asarray([False, False, True, False, True]) + truth = np.asarray([[2], [4]], dtype=np.int64) + self._testWhere(x, truth) + + def testRandomVec(self): + x = np.random.rand(1000000) > 0.5 + truth = np.vstack([np.where(x)[0].astype(np.int64)]).T + self._testWhere(x, truth) + def testBasicMat(self): x = np.asarray([[True, False], [True, False]]) @@ -67,10 +93,37 @@ class WhereOpTest(test.TestCase): def testThreeArgument(self): x = np.array([[-2, 3, -1], [1, -3, -3]]) np_val = np.where(x > 0, x * x, -x) - with self.test_session(): + with self.test_session(use_gpu=True): tf_val = array_ops.where(constant_op.constant(x) > 0, x * x, -x).eval() self.assertAllEqual(tf_val, np_val) +class WhereBenchmark(test.Benchmark): + + def benchmarkWhereCPU(self): + for (m, n, p, use_gpu) in itertools.product( + [10], + [10, 100, 1000, 10000, 100000, 1000000], + [0.01, 0.5, 0.99], + [False, True]): + name = "m_%d_n_%d_p_%g_use_gpu_%s" % (m, n, p, use_gpu) + device = "/%s:0" % ("gpu" if use_gpu else "cpu") + with ops.Graph().as_default(): + with ops.device(device): + x = random_ops.random_uniform((m, n), dtype=dtypes.float32) <= p + v = resource_variable_ops.ResourceVariable(x) + op = array_ops.where(v) + with session.Session() as sess: + v.initializer.run() + r = self.run_op_benchmark(sess, op, min_iters=100, name=name) + gb_processed_input = m * n / 1.0e9 + # approximate size of output: m*n*p int64s for each axis. + gb_processed_output = 2 * 8 * m * n * p / 1.0e9 + gb_processed = gb_processed_input + gb_processed_output + throughput = gb_processed / r["wall_time"] + print("Benchmark: %s \t wall_time: %0.03g s \t " + "Throughput: %0.03g GB/s" % (name, r["wall_time"], throughput)) + sys.stdout.flush() + if __name__ == "__main__": test.main() diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index 51ba3b7a0b..9da5d5cb5b 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -87,6 +87,7 @@ genrule( "//third_party/fft2d:LICENSE", "@boringssl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", + "@cub_archive//:LICENSE.TXT", "@curl//:COPYING", "@eigen_archive//:COPYING.MPL2", "@farmhash_archive//:COPYING", @@ -117,6 +118,7 @@ genrule( "//third_party/fft2d:LICENSE", "@boringssl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", + "@cub_archive//:LICENSE.TXT", "@curl//:COPYING", "@eigen_archive//:COPYING.MPL2", "@farmhash_archive//:COPYING", diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 6a6ce5b9e3..0bd7cda1c5 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -95,6 +95,7 @@ filegroup( "//third_party/hadoop:LICENSE.txt", "@boringssl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", + "@cub_archive//:LICENSE.TXT", "@curl//:COPYING", "@eigen_archive//:COPYING.MPL2", "@farmhash_archive//:COPYING", diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index a5172187f7..a6c5117e22 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -643,3 +643,19 @@ def tf_workspace(path_prefix="", tf_repo_name=""): strip_prefix = "pprof-c0fb62ec88c411cc91194465e54db2632845b650", build_file = str(Label("//third_party:pprof.BUILD")), ) + + native.new_http_archive( + name = "cub_archive", + urls = [ + "http://mirror.bazel.build/github.com/NVlabs/cub/archive/1.6.4.zip", + "https://github.com/NVlabs/cub/archive/1.6.4.zip", + ], + sha256 = "966d0c4f41e2bdc81aebf9ccfbf0baffaac5a74f00b826b06f4dee79b2bb8cee", + strip_prefix = "cub-1.6.4", + build_file = str(Label("//third_party:cub.BUILD")), + ) + + native.bind( + name = "cub", + actual = "@cub_archive//:cub", + ) diff --git a/third_party/cub.BUILD b/third_party/cub.BUILD new file mode 100644 index 0000000000..29159c9dad --- /dev/null +++ b/third_party/cub.BUILD @@ -0,0 +1,26 @@ +# Description: CUB library which is a set of primitives for GPU programming. + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # BSD + +exports_files(["LICENSE.TXT"]) + +load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "if_cuda") + +filegroup( + name = "cub_header_files", + srcs = glob([ + "cub/**", + ]), +) + +cc_library( + name = "cub", + hdrs = if_cuda([":cub_header_files"]), + deps = [ + "@local_config_cuda//cuda:cuda_headers", + ], +) |