diff options
Diffstat (limited to 'tensorflow/core/kernels/concat_op.cc')
-rw-r--r-- | tensorflow/core/kernels/concat_op.cc | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc
index 8122cee574..e9e18a9a37 100644
--- a/tensorflow/core/kernels/concat_op.cc
+++ b/tensorflow/core/kernels/concat_op.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 // See docs in ../ops/array_ops.cc.
 
+#include <limits>
 #include <vector>
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -119,7 +120,14 @@ class ConcatOp : public OpKernel {
       int64 output_dim1 = output->NumElements() / inputs_flat_dim0;
       auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1});
       if (std::is_same<Device, GPUDevice>::value) {
-        ConcatGPU<T>(c->eigen_gpu_device(), inputs_flat, &output_flat);
+        // Switching indexing to int64 might cause performance issues.
+        // Hence, we keep int32 indexing in the GPU kernel unless we need to
+        // switch to int64.
+        if (output->NumElements() < std::numeric_limits<int32>::max()) {
+          ConcatGPU32<T>(c->eigen_gpu_device(), inputs_flat, &output_flat);
+        } else {
+          ConcatGPU64<T>(c->eigen_gpu_device(), inputs_flat, &output_flat);
+        }
       } else {
         ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
       }