diff options
author | 2016-04-07 14:04:39 -0800 | |
---|---|---|
committer | 2016-04-07 15:14:49 -0700 | |
commit | 867227c1317d05ab1b6e5c2341212ed5741da049 (patch) | |
tree | 6d25ce0c08744ae28e638e6ba035005054a043b5 | |
parent | 8eb7343c89223ae6c9d1fc3512607092556f885c (diff) |
Fix int-overflow errors in concat cpu kernel.
Change: 119312877
-rw-r--r-- | tensorflow/core/kernels/concat_lib.h | 16 | ||||
-rw-r--r-- | tensorflow/core/kernels/concat_lib_cpu.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/kernels/concat_lib_gpu.cu.cc | 49 | ||||
-rw-r--r-- | tensorflow/core/kernels/concat_op.cc | 10 | ||||
-rw-r--r-- | tensorflow/core/kernels/pack_op.cc | 13 | ||||
-rw-r--r-- | tensorflow/core/kernels/tensor_array_ops.cc | 25 |
6 files changed, 92 insertions, 23 deletions
diff --git a/tensorflow/core/kernels/concat_lib.h b/tensorflow/core/kernels/concat_lib.h index 77f92f5428..2f11986ae1 100644 --- a/tensorflow/core/kernels/concat_lib.h +++ b/tensorflow/core/kernels/concat_lib.h @@ -32,10 +32,18 @@ void ConcatCPU(DeviceBase* d, // Assumes all inputs are nonempty template <typename T> -void ConcatGPU(const Eigen::GpuDevice& d, - const std::vector< - std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& inputs, - typename TTypes<T, 2>::Matrix* output); +void ConcatGPU32( + const Eigen::GpuDevice& d, + const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& + inputs, + typename TTypes<T, 2>::Matrix* output); + +template <typename T> +void ConcatGPU64( + const Eigen::GpuDevice& d, + const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& + inputs, + typename TTypes<T, 2>::Matrix* output); } // namespace tensorflow diff --git a/tensorflow/core/kernels/concat_lib_cpu.cc b/tensorflow/core/kernels/concat_lib_cpu.cc index f5431e6c0f..282ee533a5 100644 --- a/tensorflow/core/kernels/concat_lib_cpu.cc +++ b/tensorflow/core/kernels/concat_lib_cpu.cc @@ -41,7 +41,7 @@ void ConcatCPU(DeviceBase* d, int num_inputs = inputs.size(); std::vector<ptrdiff_t> sizes; sizes.reserve(num_inputs); - int row_size = 0; + int64 row_size = 0; for (int j = 0; j < num_inputs; ++j) { sizes.push_back(inputs[j]->dimension(1)); row_size += sizes.back(); diff --git a/tensorflow/core/kernels/concat_lib_gpu.cu.cc b/tensorflow/core/kernels/concat_lib_gpu.cu.cc index 018d551eb7..04b80b2fb8 100644 --- a/tensorflow/core/kernels/concat_lib_gpu.cu.cc +++ b/tensorflow/core/kernels/concat_lib_gpu.cu.cc @@ -32,13 +32,14 @@ namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; template <typename T> -void ConcatGPU(const GPUDevice& d, - const std::vector< - std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& inputs, - typename TTypes<T, 2>::Matrix* output) { +void ConcatGPU32( + const GPUDevice& d, + const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& + inputs, + typename TTypes<T, 2>::Matrix* output) { Eigen::array<int32, 2> offset{0, 0}; for (int i = 0; i < inputs.size(); ++i) { - Eigen::array<int32_t, 2> size; + Eigen::array<int32, 2> size; size[0] = inputs[i]->dimension(0); size[1] = inputs[i]->dimension(1); To32Bit(*output).slice(offset, size).device(d) = To32Bit(*inputs[i]); @@ -46,16 +47,44 @@ void ConcatGPU(const GPUDevice& d, } } -#define REGISTER_GPU(T) \ - template void ConcatGPU<T>( \ +template <typename T> +void ConcatGPU64( + const GPUDevice& d, + const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& + inputs, + typename TTypes<T, 2>::Matrix* output) { + Eigen::array<int64, 2> offset{0, 0}; + for (int i = 0; i < inputs.size(); ++i) { + Eigen::array<int64, 2> size; + size[0] = inputs[i]->dimension(0); + size[1] = inputs[i]->dimension(1); + output->slice(offset, size).device(d) = *inputs[i]; + offset[1] += size[1]; + } +} + +#define REGISTER_GPU32(T) \ + template void ConcatGPU32<T>( \ const GPUDevice& d, \ const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \ inputs, \ typename TTypes<T, 2>::Matrix* output); -TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); -REGISTER_GPU(bfloat16); -#undef REGISTER_GPU +#define REGISTER_GPU64(T) \ + template void ConcatGPU64<T>( \ + const GPUDevice& d, \ + const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \ + inputs, \ + typename TTypes<T, 2>::Matrix* output); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU32); +REGISTER_GPU32(bfloat16); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU64); +REGISTER_GPU64(bfloat16); + +#undef REGISTER_GPU32 +#undef REGISTER_GPU64 } // end namespace tensorflow diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc index 8122cee574..e9e18a9a37 100644 --- a/tensorflow/core/kernels/concat_op.cc +++ b/tensorflow/core/kernels/concat_op.cc @@ -15,6 +15,7 @@ limitations under the License. // See docs in ../ops/array_ops.cc. +#include <limits> #include <vector> #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -119,7 +120,14 @@ class ConcatOp : public OpKernel { int64 output_dim1 = output->NumElements() / inputs_flat_dim0; auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1}); if (std::is_same<Device, GPUDevice>::value) { - ConcatGPU<T>(c->eigen_gpu_device(), inputs_flat, &output_flat); + // Switching indexing to int64 might cause performance issues. + // Hence, we keep int32 indexing in the GPU kernel unless we need to + // switch to int64. + if (output->NumElements() < std::numeric_limits<int32>::max()) { + ConcatGPU64<T>(c->eigen_gpu_device(), inputs_flat, &output_flat); + } else { + ConcatGPU32<T>(c->eigen_gpu_device(), inputs_flat, &output_flat); + } } else { ConcatCPU<T>(c->device(), inputs_flat, &output_flat); } diff --git a/tensorflow/core/kernels/pack_op.cc b/tensorflow/core/kernels/pack_op.cc index 190063f2b8..e9fc0f438e 100644 --- a/tensorflow/core/kernels/pack_op.cc +++ b/tensorflow/core/kernels/pack_op.cc @@ -15,6 +15,7 @@ limitations under the License. // See docs in ../ops/array_ops.cc. +#include <limits> #include <vector> #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -54,7 +55,6 @@ class PackOp : public OpKernel { values[0].shape().DebugString(), " != values[", i, "].shape = ", values[i].shape().DebugString())); } - TensorShape output_shape(values[0].shape()); output_shape.InsertDim(0, num); @@ -70,7 +70,7 @@ class PackOp : public OpKernel { Tensor* output; OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output)); - const int output_size = output->NumElements(); + const int64 output_size = output->NumElements(); if (output_size > 0) { auto output_flat = output->shaped<T, 2>({1, output_size}); @@ -83,7 +83,14 @@ class PackOp : public OpKernel { values[i].shaped<T, 2>({1, values[i].NumElements()}))); } if (std::is_same<Device, GPUDevice>::value) { - ConcatGPU<T>(c->eigen_gpu_device(), inputs_flat, &output_flat); + // Switching indexing to int64 might cause performance issues. + // Hence, we keep int32 indexing in the GPU kernel unless we need to + // switch to int64. + if (output_size < std::numeric_limits<int32>::max()) { + ConcatGPU32<T>(c->eigen_gpu_device(), inputs_flat, &output_flat); + } else { + ConcatGPU64<T>(c->eigen_gpu_device(), inputs_flat, &output_flat); + } } else { ConcatCPU<T>(c->device(), inputs_flat, &output_flat); } diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc index 70ef00292a..55d1f514f6 100644 --- a/tensorflow/core/kernels/tensor_array_ops.cc +++ b/tensorflow/core/kernels/tensor_array_ops.cc @@ -17,7 +17,7 @@ limitations under the License. #define EIGEN_USE_THREADS -#include <limits.h> +#include <limits> #include <vector> #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -440,7 +440,16 @@ class TensorArrayPackOp : public OpKernel { } if (std::is_same<Device, GPUDevice>::value) { - ConcatGPU<T>(ctx->eigen_gpu_device(), input_tensors_flat, &output_flat); + // Switching indexing to int64 might cause performance issues. + // Hence, we keep int32 indexing in the GPU kernel unless we need to + // switch to int64. + if (output_shape.num_elements() < std::numeric_limits<int32>::max()) { + ConcatGPU32<T>(ctx->eigen_gpu_device(), input_tensors_flat, + &output_flat); + } else { + ConcatGPU64<T>(ctx->eigen_gpu_device(), input_tensors_flat, + &output_flat); + } } else { ConcatCPU<T>(ctx->device(), input_tensors_flat, &output_flat); } @@ -576,7 +585,6 @@ class TensorArrayConcatOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output_tensor)); ConstMatrixVector input_tensors_flat; input_tensors_flat.reserve(values.size()); - for (size_t i = 0; i < values.size(); ++i) { const Tensor* value_t = value_tensors[i]; if (value_t->NumElements() > 0) { @@ -589,7 +597,16 @@ class TensorArrayConcatOp : public OpKernel { auto output_flat = output_tensor->shaped<T, 2>({1, output_shape.num_elements()}); if (std::is_same<Device, GPUDevice>::value) { - ConcatGPU<T>(ctx->eigen_gpu_device(), input_tensors_flat, &output_flat); + // Switching indexing to int64 might cause performance issues. + // Hence, we keep int32 indexing in the GPU kernel unless we need to + // switch to int64. + if (output_shape.num_elements() < std::numeric_limits<int32>::max()) { + ConcatGPU32<T>(ctx->eigen_gpu_device(), input_tensors_flat, + &output_flat); + } else { + ConcatGPU64<T>(ctx->eigen_gpu_device(), input_tensors_flat, + &output_flat); + } } else { ConcatCPU<T>(ctx->device(), input_tensors_flat, &output_flat); } |