aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/concat_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/concat_op.cc')
-rw-r--r--tensorflow/core/kernels/concat_op.cc10
1 files changed, 9 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc
index 8122cee574..e9e18a9a37 100644
--- a/tensorflow/core/kernels/concat_op.cc
+++ b/tensorflow/core/kernels/concat_op.cc
@@ -15,6 +15,7 @@ limitations under the License.
// See docs in ../ops/array_ops.cc.
+#include <limits>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -119,7 +120,14 @@ class ConcatOp : public OpKernel {
int64 output_dim1 = output->NumElements() / inputs_flat_dim0;
auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1});
if (std::is_same<Device, GPUDevice>::value) {
- ConcatGPU<T>(c->eigen_gpu_device(), inputs_flat, &output_flat);
+ // Switching indexing to int64 might cause performance issues.
+ // Hence, we keep int32 indexing in the GPU kernel unless we need to
+ // switch to int64.
+ if (output->NumElements() < std::numeric_limits<int32>::max()) {
+ ConcatGPU64<T>(c->eigen_gpu_device(), inputs_flat, &output_flat);
+ } else {
+ ConcatGPU32<T>(c->eigen_gpu_device(), inputs_flat, &output_flat);
+ }
} else {
ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
}