author | 2017-08-10 17:55:10 -0700 |
---|---|---
committer | 2017-08-10 17:58:51 -0700 |
commit | e2a163a90561bef0accdd7a0f200f692d85e14c9 (patch) |
tree | 60ff33ab90d20fabb8cc3d460d74208c20c12bd6 | /tensorflow/core/kernels/dynamic_stitch_op.cc
parent | 9fba8c185164ddd4d9ff0483499dc158475883b2 (diff) |
Merge code from PR #11940 with internal changes from cl/164796436, and update Python tests to also run on GPU.
PiperOrigin-RevId: 164929133
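For context, `DynamicStitch` scatters slices from several input tensors into one output, using one int32 index vector per input; when the same output index appears in more than one place, the last occurrence wins. Below is a minimal sketch of those semantics in plain C++ (scalar slices, all names hypothetical, not the kernel's actual code):

```cpp
#include <algorithm>
#include <vector>

// Hypothetical sketch: merged[indices[i][j]] = data[i][j], with later writes
// overwriting earlier ones when an output index is duplicated.
std::vector<float> DynamicStitchSketch(
    const std::vector<std::vector<int>>& indices,
    const std::vector<std::vector<float>>& data) {
  int max_index = -1;
  for (const auto& vec : indices)
    for (int v : vec) max_index = std::max(max_index, v);
  // Note: the real op leaves rows that no index covers uninitialized;
  // this sketch zero-fills them instead.
  std::vector<float> merged(max_index + 1, 0.0f);
  for (size_t i = 0; i < indices.size(); ++i)
    for (size_t j = 0; j < indices[i].size(); ++j)
      merged[indices[i][j]] = data[i][j];
  return merged;
}
```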
Diffstat (limited to 'tensorflow/core/kernels/dynamic_stitch_op.cc')
-rw-r--r-- | tensorflow/core/kernels/dynamic_stitch_op.cc | 141 |
1 file changed, 120 insertions(+), 21 deletions(-)
```diff
diff --git a/tensorflow/core/kernels/dynamic_stitch_op.cc b/tensorflow/core/kernels/dynamic_stitch_op.cc
index f011f34fa8..99bcd90a4e 100644
--- a/tensorflow/core/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/core/kernels/dynamic_stitch_op.cc
@@ -21,8 +21,17 @@ limitations under the License.
 #include "tensorflow/core/kernels/bounds_check.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 
+#ifdef GOOGLE_CUDA
+#include "tensorflow/core/kernels/cuda_device_array.h"
+#endif  // GOOGLE_CUDA
+
 namespace tensorflow {
 
+typedef Eigen::ThreadPoolDevice CPUDevice;
+#ifdef GOOGLE_CUDA
+typedef Eigen::GpuDevice GPUDevice;
+#endif  // GOOGLE_CUDA
+
 template <class T>
 class DynamicStitchOpImplBase : public OpKernel {
  public:
@@ -66,17 +75,24 @@ class DynamicStitchOpImplBase : public OpKernel {
   void CheckArgsAndAllocateResult(OpKernelContext* c,
                                   OpInputList* indices_inputs,
                                   OpInputList* data_inputs, int* first_dim_size,
+                                  int* data_elements_size,
                                   Tensor** result_ptr) {
     // Find maximum index in the indices vectors
     OP_REQUIRES_OK(c, c->input_list("indices", indices_inputs));
     int32 max_index = -1;
+    if (data_elements_size) {
+      *data_elements_size = 0;
+    }
     for (const Tensor& indices : *indices_inputs) {
       if (indices.NumElements() > 0) {
         Eigen::Tensor<int32, 0, Eigen::RowMajor> m =
             indices.flat<int32>().maximum();
         max_index = std::max(m(), max_index);
       }
+      if (data_elements_size) {
+        *data_elements_size += indices.NumElements();
+      }
     }
 
     *first_dim_size = max_index + 1;
@@ -90,18 +106,19 @@ class DynamicStitchOpImplBase : public OpKernel {
       const Tensor& data = (*data_inputs)[input_num];
       OP_REQUIRES(
           c, TensorShapeUtils::StartsWith(data.shape(), indices.shape()),
-          errors::InvalidArgument("data[", input_num, "].shape = ",
-                                  data.shape().DebugString(),
+          errors::InvalidArgument("data[", input_num,
+                                  "].shape = ", data.shape().DebugString(),
                                   " does not start with indices[", input_num,
                                   "].shape = ", indices.shape().DebugString()));
       OP_REQUIRES(
           c, input_num == 0 || SameExtraShape(data0, indices0, data, indices),
           errors::InvalidArgument(
               "Need data[0].shape[", indices0.dims(), ":] = data[", input_num,
-              "].shape[", indices.dims(), ":], got data[0].shape = ",
-              data0.shape().DebugString(), ", data[", input_num, "].shape = ",
-              data.shape().DebugString(), ", indices[0].shape = ",
-              indices0.shape().DebugString(), ", indices[", input_num,
+              "].shape[", indices.dims(),
+              ":], got data[0].shape = ", data0.shape().DebugString(),
+              ", data[", input_num, "].shape = ", data.shape().DebugString(),
+              ", indices[0].shape = ", indices0.shape().DebugString(),
+              ", indices[", input_num,
               "].shape = ", indices.shape().DebugString()));
     }
 
@@ -116,10 +133,90 @@ class DynamicStitchOpImplBase : public OpKernel {
   }
 };
 
```
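The only change to the shared base class is the new optional `data_elements_size` out-parameter: it accumulates `indices.NumElements()` over all index tensors, i.e. the total number of data slices, which the GPU path below needs in order to size its flat array of slice pointers. A standalone sketch of the two sizes the helper now computes (simplified types, hypothetical names):

```cpp
#include <algorithm>
#include <vector>

// first_dim_size = max index + 1 (number of output rows);
// data_elements_size = total slice count across all inputs.
void ComputeSizes(const std::vector<std::vector<int>>& indices_inputs,
                  int* first_dim_size, int* data_elements_size) {
  int max_index = -1;
  *data_elements_size = 0;
  for (const auto& indices : indices_inputs) {
    for (int v : indices) max_index = std::max(max_index, v);
    *data_elements_size += static_cast<int>(indices.size());
  }
  *first_dim_size = max_index + 1;
}
```

The diff continues with the new GPU implementation: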
```diff
+#if GOOGLE_CUDA
+
+template <typename T>
+void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
+                          const int32 slice_size, const int32 first_dim_size,
+                          const CudaDeviceArrayStruct<int>& input_indices,
+                          const CudaDeviceArrayStruct<const T*>& input_ptrs,
+                          T* output);
+
+template <class T>
+class DynamicStitchOpGPU : public DynamicStitchOpImplBase<T> {
+ public:
+  explicit DynamicStitchOpGPU(OpKernelConstruction* c)
+      : DynamicStitchOpImplBase<T>(c, "DynamicStitchOp") {}
+
+  void Compute(OpKernelContext* c) override {
+    OpInputList indices_inputs;
+    OpInputList data_inputs;
+    int first_dim_size;
+    int data_elements_size;
+    Tensor* merged = nullptr;
+    this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
+                                     &first_dim_size, &data_elements_size,
+                                     &merged);
+    if (!c->status().ok()) {
+      // Avoid segmentation faults if merged cannot be allocated and an
+      // error is passed back in the context.
+      return;
+    }
+
+    // TODO(jeff): Currently we leave uninitialized any portions of
+    // merged that aren't covered by an index in indices. What should we do?
+    if (first_dim_size > 0) {
+      // Because indices may collide, duplicates have to be resolved
+      // before the data is sent to the GPU kernel.
+      // TODO(ekelsen): Instead of doing a serial scan on the CPU to pick the
+      // last of duplicated indices, it could instead be done on the GPU
+      // implicitly using atomics to make sure the last index is the final
+      // write.
+      const int slice_size = merged->flat_outer_dims<T>().dimension(1);
+      CudaDeviceArrayOnHost<int32> indices_flat(c, first_dim_size);
+      CudaDeviceArrayOnHost<const T*> data_flat(c, data_elements_size);
+      OP_REQUIRES_OK(c, indices_flat.Init());
+      OP_REQUIRES_OK(c, data_flat.Init());
+      // Initialize indices_flat (-1 represents a missing index).
+      for (int i = 0; i < first_dim_size; ++i) {
+        indices_flat.Set(i, -1);
+      }
+
+      // data_flat index.
+      int32 idx = 0;
+      // Running sum of indices_inputs[i].NumElements(), used to compute the
+      // indices_flat values.
+      int32 base_size = 0;
+      for (int i = 0; i < indices_inputs.size(); ++i) {
+        auto indices_vec = indices_inputs[i].flat<int32>();
+        auto data_ptr_base = data_inputs[i].template flat<T>().data();
+        for (int j = 0; j < indices_vec.size(); ++j) {
+          // indices_flat's indices represent the indices of the output;
+          // its values are the flat positions in the input data where the
+          // corresponding slice is located.
+          indices_flat.Set(indices_vec(j), base_size + j);
+          data_flat.Set(
+              idx, const_cast<T*>(reinterpret_cast<const T*>(data_ptr_base) +
+                                  j * slice_size));
+          ++idx;
+        }
+        base_size += indices_vec.size();
+      }
+      OP_REQUIRES_OK(c, indices_flat.Finalize());
+      OP_REQUIRES_OK(c, data_flat.Finalize());
+
+      auto output = merged->template flat<T>().data();
+      DynamicStitchGPUImpl<T>(c->eigen_gpu_device(), slice_size,
+                              first_dim_size, indices_flat.data(),
+                              data_flat.data(), output);
+    }
+  }
+};
+
+#endif  // GOOGLE_CUDA
+
 template <class T, bool Parallel>
-class DynamicStitchOpImpl : public DynamicStitchOpImplBase<T> {
+class DynamicStitchOpImplCPU : public DynamicStitchOpImplBase<T> {
  public:
-  explicit DynamicStitchOpImpl(OpKernelConstruction* c)
+  explicit DynamicStitchOpImplCPU(OpKernelConstruction* c)
       : DynamicStitchOpImplBase<T>(
             c, (Parallel ? "ParallelDynamicStitchOp" : "DynamicStitchOp")) {}
 
@@ -129,7 +226,7 @@ class DynamicStitchOpImpl : public DynamicStitchOpImplBase<T> {
     int first_dim_size;
     Tensor* merged = nullptr;
     this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
-                                     &first_dim_size, &merged);
+                                     &first_dim_size, nullptr, &merged);
     if (!c->status().ok()) {
       // Avoid segmentation faults if merged cannot be allocated and an error is
       // passed back in the context.
       return;
     }
```
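The serial scan above resolves index collisions on the host before launch: `indices_flat` maps each output row to the flat id of the slice that fills it (-1 for rows no input covers; later inputs overwrite earlier entries, giving the documented last-wins behavior), while `data_flat` maps flat slice ids to raw slice pointers. A standalone sketch of the same scheme with the `CudaDeviceArrayOnHost` wrappers swapped for plain vectors (names hypothetical):

```cpp
#include <vector>

// Host-side collision resolution, mirroring DynamicStitchOpGPU::Compute.
template <typename T>
void FlattenForGpu(const std::vector<std::vector<int>>& indices_inputs,
                   const std::vector<const T*>& data_bases,  // base pointer per input
                   int first_dim_size, int slice_size,
                   std::vector<int>* indices_flat,      // output row -> flat slice id
                   std::vector<const T*>* data_flat) {  // flat slice id -> slice start
  indices_flat->assign(first_dim_size, -1);  // -1 marks rows no input covers.
  data_flat->clear();
  int base_size = 0;  // Slices contributed by earlier inputs.
  for (size_t i = 0; i < indices_inputs.size(); ++i) {
    const std::vector<int>& vec = indices_inputs[i];
    for (size_t j = 0; j < vec.size(); ++j) {
      // Later writes overwrite earlier ones: the last duplicate wins.
      (*indices_flat)[vec[j]] = base_size + static_cast<int>(j);
      data_flat->push_back(data_bases[i] + j * slice_size);
    }
    base_size += static_cast<int>(vec.size());
  }
}
```

The remaining hunks rename the CPU entry points and adjust kernel registration: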
```diff
@@ -207,13 +304,13 @@ class DynamicStitchOpImpl : public DynamicStitchOpImplBase<T> {
 // functionality later.
 
 template <typename T>
-struct DynamicStitchOp : DynamicStitchOpImpl<T, false> {
-  using DynamicStitchOpImpl<T, false>::DynamicStitchOpImpl;
+struct DynamicStitchOpCPU : DynamicStitchOpImplCPU<T, false> {
+  using DynamicStitchOpImplCPU<T, false>::DynamicStitchOpImplCPU;
 };
 
 template <typename T>
-struct ParallelDynamicStitchOp : DynamicStitchOpImpl<T, true> {
-  using DynamicStitchOpImpl<T, true>::DynamicStitchOpImpl;
+struct ParallelDynamicStitchOpCPU : DynamicStitchOpImplCPU<T, true> {
+  using DynamicStitchOpImplCPU<T, true>::DynamicStitchOpImplCPU;
 };
 
 #define REGISTER_DYNAMIC_STITCH(type)                     \
@@ -221,12 +318,12 @@ struct ParallelDynamicStitchOp : DynamicStitchOpImpl<T, true> {
                               .Device(DEVICE_CPU)        \
                               .TypeConstraint<type>("T") \
                               .HostMemory("indices"),    \
-                          DynamicStitchOp<type>)         \
+                          DynamicStitchOpCPU<type>)      \
   REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch")  \
                               .Device(DEVICE_CPU)        \
                               .TypeConstraint<type>("T") \
                               .HostMemory("indices"),    \
-                          ParallelDynamicStitchOp<type>)
+                          ParallelDynamicStitchOpCPU<type>)
 
 TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH);
 #undef REGISTER_DYNAMIC_STITCH
@@ -236,19 +333,21 @@ TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH);
   REGISTER_KERNEL_BUILDER(Name("DynamicStitch")          \
                               .Device(DEVICE_GPU)        \
                               .TypeConstraint<type>("T") \
-                              .HostMemory("indices")     \
-                              .HostMemory("data")        \
-                              .HostMemory("merged"),     \
-                          DynamicStitchOp<type>)         \
+                              .HostMemory("indices"),    \
+                          DynamicStitchOpGPU<type>)      \
   REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch")  \
                               .Device(DEVICE_GPU)        \
                               .TypeConstraint<type>("T") \
                               .HostMemory("indices")     \
                               .HostMemory("data")        \
                               .HostMemory("merged"),     \
-                          ParallelDynamicStitchOp<type>)
+                          ParallelDynamicStitchOpCPU<type>)
 
-TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
+TF_CALL_complex64(REGISTER_DYNAMIC_STITCH_GPU);
+TF_CALL_complex128(REGISTER_DYNAMIC_STITCH_GPU);
+TF_CALL_int64(REGISTER_DYNAMIC_STITCH_GPU);
+TF_CALL_int32(REGISTER_DYNAMIC_STITCH_GPU);
 #undef REGISTER_DYNAMIC_STITCH_GPU
 
 #endif  // GOOGLE_CUDA
```
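Two things are worth noting in the registration block. First, `DynamicStitch` on `DEVICE_GPU` now pins only `indices` to host memory, so `data` and `merged` stay on the device and the work is done by `DynamicStitchGPUImpl`; `ParallelDynamicStitch` on GPU, by contrast, still pins everything to host and falls back to the renamed CPU implementation. Second, the GPU registration now covers the GPU number types plus complex64/complex128/int64/int32 rather than all POD and string types. The `DynamicStitchGPUImpl` kernel is only declared in this file (its definition presumably lives in a companion `.cu.cc` file); the plain C++ loop below is a hedged emulation of the gather it performs, one logical thread per output element:

```cpp
#include <vector>

// CPU emulation of the device-side gather: element i of the output belongs to
// row i / slice_size; indices_flat says which flattened input slice fills that
// row, and data_flat holds the pointer to that slice's first element.
template <typename T>
void EmulateDynamicStitchGpuKernel(int slice_size, int first_dim_size,
                                   const std::vector<int>& indices_flat,
                                   const std::vector<const T*>& data_flat,
                                   T* output) {
  for (int i = 0; i < first_dim_size * slice_size; ++i) {
    const int row = i / slice_size;
    const int col = i % slice_size;
    const int slice_id = indices_flat[row];
    if (slice_id >= 0) output[i] = data_flat[slice_id][col];
    // Rows with slice_id == -1 were never indexed and are left untouched,
    // matching the TODO(jeff) note in the diff above.
  }
}
```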