path: root/tensorflow/core/kernels/dynamic_stitch_op.cc
author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-08-10 17:55:10 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-08-10 17:58:51 -0700
commit    e2a163a90561bef0accdd7a0f200f692d85e14c9 (patch)
tree      60ff33ab90d20fabb8cc3d460d74208c20c12bd6 /tensorflow/core/kernels/dynamic_stitch_op.cc
parent    9fba8c185164ddd4d9ff0483499dc158475883b2 (diff)
Merge code from PR #11940 with internal changes from cl/164796436, and update Python tests to also run on GPU.
PiperOrigin-RevId: 164929133
Diffstat (limited to 'tensorflow/core/kernels/dynamic_stitch_op.cc')
-rw-r--r--  tensorflow/core/kernels/dynamic_stitch_op.cc  141
1 file changed, 120 insertions(+), 21 deletions(-)
diff --git a/tensorflow/core/kernels/dynamic_stitch_op.cc b/tensorflow/core/kernels/dynamic_stitch_op.cc
index f011f34fa8..99bcd90a4e 100644
--- a/tensorflow/core/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/core/kernels/dynamic_stitch_op.cc
@@ -21,8 +21,17 @@ limitations under the License.
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/threadpool.h"
+#ifdef GOOGLE_CUDA
+#include "tensorflow/core/kernels/cuda_device_array.h"
+#endif // GOOGLE_CUDA
+
namespace tensorflow {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+#ifdef GOOGLE_CUDA
+typedef Eigen::GpuDevice GPUDevice;
+#endif // GOOGLE_CUDA
+
template <class T>
class DynamicStitchOpImplBase : public OpKernel {
public:
@@ -66,17 +75,24 @@ class DynamicStitchOpImplBase : public OpKernel {
void CheckArgsAndAllocateResult(OpKernelContext* c,
OpInputList* indices_inputs,
OpInputList* data_inputs, int* first_dim_size,
+ int* data_elements_size,
Tensor** result_ptr) {
// Find maximum index in the indices vectors
OP_REQUIRES_OK(c, c->input_list("indices", indices_inputs));
int32 max_index = -1;
+ if (data_elements_size) {
+ *data_elements_size = 0;
+ }
for (const Tensor& indices : *indices_inputs) {
if (indices.NumElements() > 0) {
Eigen::Tensor<int32, 0, Eigen::RowMajor> m =
indices.flat<int32>().maximum();
max_index = std::max(m(), max_index);
}
+ if (data_elements_size) {
+ *data_elements_size += indices.NumElements();
+ }
}
*first_dim_size = max_index + 1;
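
(Illustration, not part of the patch: a minimal standalone sketch of the counting logic in CheckArgsAndAllocateResult above, using plain std::vector in place of OpInputList/Tensor; the name CountIndices is made up for this sketch.)

#include <algorithm>
#include <cstdint>
#include <vector>

// Given the flattened "indices" inputs, compute the size of the first output
// dimension (max index + 1) and the total number of data slices, mirroring
// the counting done in CheckArgsAndAllocateResult.
void CountIndices(const std::vector<std::vector<int32_t>>& indices_inputs,
                  int* first_dim_size, int* data_elements_size) {
  int32_t max_index = -1;
  *data_elements_size = 0;
  for (const auto& indices : indices_inputs) {
    if (!indices.empty()) {
      max_index = std::max(*std::max_element(indices.begin(), indices.end()),
                           max_index);
    }
    *data_elements_size += static_cast<int>(indices.size());
  }
  *first_dim_size = max_index + 1;
}
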
@@ -90,18 +106,19 @@ class DynamicStitchOpImplBase : public OpKernel {
const Tensor& data = (*data_inputs)[input_num];
OP_REQUIRES(
c, TensorShapeUtils::StartsWith(data.shape(), indices.shape()),
- errors::InvalidArgument("data[", input_num, "].shape = ",
- data.shape().DebugString(),
+ errors::InvalidArgument("data[", input_num,
+ "].shape = ", data.shape().DebugString(),
" does not start with indices[", input_num,
"].shape = ", indices.shape().DebugString()));
OP_REQUIRES(
c, input_num == 0 || SameExtraShape(data0, indices0, data, indices),
errors::InvalidArgument(
"Need data[0].shape[", indices0.dims(), ":] = data[", input_num,
- "].shape[", indices.dims(), ":], got data[0].shape = ",
- data0.shape().DebugString(), ", data[", input_num, "].shape = ",
- data.shape().DebugString(), ", indices[0].shape = ",
- indices0.shape().DebugString(), ", indices[", input_num,
+ "].shape[", indices.dims(),
+ ":], got data[0].shape = ", data0.shape().DebugString(),
+ ", data[", input_num, "].shape = ", data.shape().DebugString(),
+ ", indices[0].shape = ", indices0.shape().DebugString(),
+ ", indices[", input_num,
"].shape = ", indices.shape().DebugString()));
}
@@ -116,10 +133,90 @@ class DynamicStitchOpImplBase : public OpKernel {
}
};
+#if GOOGLE_CUDA
+
+template <typename T>
+void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
+ const int32 slice_size, const int32 first_dim_size,
+ const CudaDeviceArrayStruct<int>& input_indices,
+ const CudaDeviceArrayStruct<const T*>& input_ptrs,
+ T* output);
+
+template <class T>
+class DynamicStitchOpGPU : public DynamicStitchOpImplBase<T> {
+ public:
+ explicit DynamicStitchOpGPU(OpKernelConstruction* c)
+ : DynamicStitchOpImplBase<T>(c, "DynamicStitchOp") {}
+
+ void Compute(OpKernelContext* c) override {
+ OpInputList indices_inputs;
+ OpInputList data_inputs;
+ int first_dim_size;
+ int data_elements_size;
+ Tensor* merged = nullptr;
+ this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
+ &first_dim_size, &data_elements_size,
+ &merged);
+ if (!c->status().ok()) {
+ // Avoid segmentation faults if merged cannot be allocated and an error is
+ // passed back in the context.
+ return;
+ }
+
+ // TODO(jeff): Currently we leave uninitialized any portions of
+ // merged that aren't covered by an index in indices. What should we do?
+ if (first_dim_size > 0) {
+ // Because of the collision requirements, we have to resolve collisions
+ // before sending the data to the GPU kernel.
+ // TODO(ekelsen): Instead of doing a serial scan on the CPU to pick the
+ // last of duplicated indices, it could instead be done on the GPU
+ // implicitly using atomics to make sure the last index is the final
+ // write.
+ const int slice_size = merged->flat_outer_dims<T>().dimension(1);
+ CudaDeviceArrayOnHost<int32> indices_flat(c, first_dim_size);
+ CudaDeviceArrayOnHost<const T*> data_flat(c, data_elements_size);
+ OP_REQUIRES_OK(c, indices_flat.Init());
+ OP_REQUIRES_OK(c, data_flat.Init());
+ // Initialize indices_flat (-1 represents a missing index).
+ for (int i = 0; i < first_dim_size; ++i) {
+ indices_flat.Set(i, -1);
+ }
+
+ // Running index into data_flat.
+ int32 idx = 0;
+ // Running sum of indices_inputs[i].NumElements(), used to compute the indices_flat values.
+ int32 base_size = 0;
+ for (int i = 0; i < indices_inputs.size(); ++i) {
+ auto indices_vec = indices_inputs[i].flat<int32>();
+ auto data_ptr_base = data_inputs[i].template flat<T>().data();
+ for (int j = 0; j < indices_vec.size(); ++j) {
+ // indices_flat's indices represent the indices of the output.
+ // indices_flat's values represent the indices into the input data where
+ // the data is located.
+ indices_flat.Set(indices_vec(j), base_size + j);
+ data_flat.Set(
+ idx, const_cast<T*>(reinterpret_cast<const T*>(data_ptr_base) +
+ j * slice_size));
+ ++idx;
+ }
+ base_size += indices_vec.size();
+ }
+ OP_REQUIRES_OK(c, indices_flat.Finalize());
+ OP_REQUIRES_OK(c, data_flat.Finalize());
+
+ auto output = merged->template flat<T>().data();
+ DynamicStitchGPUImpl<T>(c->eigen_gpu_device(), slice_size, first_dim_size,
+ indices_flat.data(), data_flat.data(), output);
+ }
+ }
+};
+
+#endif // GOOGLE_CUDA
+
template <class T, bool Parallel>
-class DynamicStitchOpImpl : public DynamicStitchOpImplBase<T> {
+class DynamicStitchOpImplCPU : public DynamicStitchOpImplBase<T> {
public:
- explicit DynamicStitchOpImpl(OpKernelConstruction* c)
+ explicit DynamicStitchOpImplCPU(OpKernelConstruction* c)
: DynamicStitchOpImplBase<T>(
c, (Parallel ? "ParallelDynamicStitchOp" : "DynamicStitchOp")) {}
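
(Illustration, not part of the patch: a host-side sketch, assuming plain std::vector inputs and float data, of the flattening the GPU path above performs before launching DynamicStitchGPUImpl; FlattenForGpu and data_bases are hypothetical names. indices_flat is indexed by output row and holds the position in data_flat of the slice that fills that row, or -1 if no index covers it; later inputs overwrite earlier entries, so duplicate indices resolve to the last write.)

#include <cstddef>
#include <cstdint>
#include <vector>

void FlattenForGpu(const std::vector<std::vector<int32_t>>& indices_inputs,
                   const std::vector<const float*>& data_bases,  // one base pointer per data input
                   int64_t slice_size, int first_dim_size,
                   std::vector<int32_t>* indices_flat,
                   std::vector<const float*>* data_flat) {
  // -1 marks output rows not covered by any input index.
  indices_flat->assign(first_dim_size, -1);
  data_flat->clear();
  int32_t base_size = 0;
  for (std::size_t i = 0; i < indices_inputs.size(); ++i) {
    const auto& indices_vec = indices_inputs[i];
    for (std::size_t j = 0; j < indices_vec.size(); ++j) {
      // Output row indices_vec[j] is served by slice (base_size + j).
      (*indices_flat)[indices_vec[j]] = base_size + static_cast<int32_t>(j);
      data_flat->push_back(data_bases[i] + j * slice_size);
    }
    base_size += static_cast<int32_t>(indices_vec.size());
  }
}
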
@@ -129,7 +226,7 @@ class DynamicStitchOpImpl : public DynamicStitchOpImplBase<T> {
int first_dim_size;
Tensor* merged = nullptr;
this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
- &first_dim_size, &merged);
+ &first_dim_size, nullptr, &merged);
if (!c->status().ok()) {
// Avoid segmentation faults if merged cannot be allocated and an error is
// passed back in the context.
@@ -207,13 +304,13 @@ class DynamicStitchOpImpl : public DynamicStitchOpImplBase<T> {
// functionality later.
template <typename T>
-struct DynamicStitchOp : DynamicStitchOpImpl<T, false> {
- using DynamicStitchOpImpl<T, false>::DynamicStitchOpImpl;
+struct DynamicStitchOpCPU : DynamicStitchOpImplCPU<T, false> {
+ using DynamicStitchOpImplCPU<T, false>::DynamicStitchOpImplCPU;
};
template <typename T>
-struct ParallelDynamicStitchOp : DynamicStitchOpImpl<T, true> {
- using DynamicStitchOpImpl<T, true>::DynamicStitchOpImpl;
+struct ParallelDynamicStitchOpCPU : DynamicStitchOpImplCPU<T, true> {
+ using DynamicStitchOpImplCPU<T, true>::DynamicStitchOpImplCPU;
};
#define REGISTER_DYNAMIC_STITCH(type) \
@@ -221,12 +318,12 @@ struct ParallelDynamicStitchOp : DynamicStitchOpImpl<T, true> {
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.HostMemory("indices"), \
- DynamicStitchOp<type>) \
+ DynamicStitchOpCPU<type>) \
REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.HostMemory("indices"), \
- ParallelDynamicStitchOp<type>)
+ ParallelDynamicStitchOpCPU<type>)
TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH);
#undef REGISTER_DYNAMIC_STITCH
@@ -236,19 +333,21 @@ TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH);
REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
- .HostMemory("indices") \
- .HostMemory("data") \
- .HostMemory("merged"), \
- DynamicStitchOp<type>) \
+ .HostMemory("indices"), \
+ DynamicStitchOpGPU<type>) \
REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.HostMemory("indices") \
.HostMemory("data") \
.HostMemory("merged"), \
- ParallelDynamicStitchOp<type>)
- ParallelDynamicStitchOpCPU<type>)
-TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
+TF_CALL_complex64(REGISTER_DYNAMIC_STITCH_GPU);
+TF_CALL_complex128(REGISTER_DYNAMIC_STITCH_GPU);
+TF_CALL_int64(REGISTER_DYNAMIC_STITCH_GPU);
+TF_CALL_int32(REGISTER_DYNAMIC_STITCH_GPU);
#undef REGISTER_DYNAMIC_STITCH_GPU
#endif // GOOGLE_CUDA
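
(Illustration, not part of the patch: a reference of the op's semantics on scalar data, assuming the last-write-wins behaviour for duplicate indices that the GPU path above relies on; DynamicStitchReference is a made-up name.)

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Simple reference of DynamicStitch on scalar data:
//   merged[indices[i][j]] = data[i][j]
// When an index appears more than once, the value written last wins.
std::vector<float> DynamicStitchReference(
    const std::vector<std::vector<int32_t>>& indices,
    const std::vector<std::vector<float>>& data) {
  int32_t max_index = -1;
  for (const auto& idx : indices)
    for (int32_t v : idx) max_index = std::max(max_index, v);
  std::vector<float> merged(max_index + 1, 0.0f);
  for (std::size_t i = 0; i < indices.size(); ++i)
    for (std::size_t j = 0; j < indices[i].size(); ++j)
      merged[indices[i][j]] = data[i][j];
  return merged;
}

int main() {
  // indices = [[0, 2], [1, 2]], data = [[10, 20], [30, 40]]
  // -> merged = [10, 30, 40]; index 2 appears twice, so the later input wins.
  auto merged =
      DynamicStitchReference({{0, 2}, {1, 2}}, {{10.f, 20.f}, {30.f, 40.f}});
  for (float v : merged) std::cout << v << " ";
  std::cout << "\n";
}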