author     Josh Levenberg <josh11b@tensorflow.org>        2016-04-07 14:04:39 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2016-04-07 15:14:49 -0700
commit     867227c1317d05ab1b6e5c2341212ed5741da049 (patch)
tree       6d25ce0c08744ae28e638e6ba035005054a043b5
parent     8eb7343c89223ae6c9d1fc3512607092556f885c (diff)
Fix int-overflow errors in concat cpu kernel.
Change: 119312877
-rw-r--r--  tensorflow/core/kernels/concat_lib.h          16
-rw-r--r--  tensorflow/core/kernels/concat_lib_cpu.cc      2
-rw-r--r--  tensorflow/core/kernels/concat_lib_gpu.cu.cc  49
-rw-r--r--  tensorflow/core/kernels/concat_op.cc          10
-rw-r--r--  tensorflow/core/kernels/pack_op.cc            13
-rw-r--r--  tensorflow/core/kernels/tensor_array_ops.cc   25
6 files changed, 92 insertions, 23 deletions
diff --git a/tensorflow/core/kernels/concat_lib.h b/tensorflow/core/kernels/concat_lib.h
index 77f92f5428..2f11986ae1 100644
--- a/tensorflow/core/kernels/concat_lib.h
+++ b/tensorflow/core/kernels/concat_lib.h
@@ -32,10 +32,18 @@ void ConcatCPU(DeviceBase* d,
// Assumes all inputs are nonempty
template <typename T>
-void ConcatGPU(const Eigen::GpuDevice& d,
- const std::vector<
- std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& inputs,
- typename TTypes<T, 2>::Matrix* output);
+void ConcatGPU32(
+ const Eigen::GpuDevice& d,
+ const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
+ inputs,
+ typename TTypes<T, 2>::Matrix* output);
+
+template <typename T>
+void ConcatGPU64(
+ const Eigen::GpuDevice& d,
+ const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
+ inputs,
+ typename TTypes<T, 2>::Matrix* output);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/concat_lib_cpu.cc b/tensorflow/core/kernels/concat_lib_cpu.cc
index f5431e6c0f..282ee533a5 100644
--- a/tensorflow/core/kernels/concat_lib_cpu.cc
+++ b/tensorflow/core/kernels/concat_lib_cpu.cc
@@ -41,7 +41,7 @@ void ConcatCPU(DeviceBase* d,
int num_inputs = inputs.size();
std::vector<ptrdiff_t> sizes;
sizes.reserve(num_inputs);
- int row_size = 0;
+ int64 row_size = 0;
for (int j = 0; j < num_inputs; ++j) {
sizes.push_back(inputs[j]->dimension(1));
row_size += sizes.back();
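
Note on the concat_lib_cpu.cc hunk above: widening row_size from int to int64 matters once the combined row width of the inputs passes 2^31 - 1. A minimal standalone sketch of the failure mode (hypothetical sizes, not code from this commit):

#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

int main() {
  // Hypothetical inputs: three matrices, each contributing 900M columns.
  std::vector<int64_t> col_sizes = {900000000, 900000000, 900000000};

  int64_t row_size = 0;  // the fix: accumulate in 64 bits, like `int64 row_size`
  for (int64_t cols : col_sizes) row_size += cols;

  std::cout << "total row width: " << row_size << "\n";  // 2700000000
  if (row_size > std::numeric_limits<int32_t>::max()) {
    // The old `int row_size` could not represent this sum (2.7e9 > 2^31 - 1),
    // so offsets derived from it in the copy loop would be wrong.
    std::cout << "does not fit in a 32-bit int\n";
  }
  return 0;
}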
diff --git a/tensorflow/core/kernels/concat_lib_gpu.cu.cc b/tensorflow/core/kernels/concat_lib_gpu.cu.cc
index 018d551eb7..04b80b2fb8 100644
--- a/tensorflow/core/kernels/concat_lib_gpu.cu.cc
+++ b/tensorflow/core/kernels/concat_lib_gpu.cu.cc
@@ -32,13 +32,14 @@ namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
template <typename T>
-void ConcatGPU(const GPUDevice& d,
- const std::vector<
- std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& inputs,
- typename TTypes<T, 2>::Matrix* output) {
+void ConcatGPU32(
+ const GPUDevice& d,
+ const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
+ inputs,
+ typename TTypes<T, 2>::Matrix* output) {
Eigen::array<int32, 2> offset{0, 0};
for (int i = 0; i < inputs.size(); ++i) {
- Eigen::array<int32_t, 2> size;
+ Eigen::array<int32, 2> size;
size[0] = inputs[i]->dimension(0);
size[1] = inputs[i]->dimension(1);
To32Bit(*output).slice(offset, size).device(d) = To32Bit(*inputs[i]);
@@ -46,16 +47,44 @@ void ConcatGPU(const GPUDevice& d,
}
}
-#define REGISTER_GPU(T) \
- template void ConcatGPU<T>( \
+template <typename T>
+void ConcatGPU64(
+ const GPUDevice& d,
+ const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
+ inputs,
+ typename TTypes<T, 2>::Matrix* output) {
+ Eigen::array<int64, 2> offset{0, 0};
+ for (int i = 0; i < inputs.size(); ++i) {
+ Eigen::array<int64, 2> size;
+ size[0] = inputs[i]->dimension(0);
+ size[1] = inputs[i]->dimension(1);
+ output->slice(offset, size).device(d) = *inputs[i];
+ offset[1] += size[1];
+ }
+}
+
+#define REGISTER_GPU32(T) \
+ template void ConcatGPU32<T>( \
const GPUDevice& d, \
const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \
inputs, \
typename TTypes<T, 2>::Matrix* output);
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
-REGISTER_GPU(bfloat16);
-#undef REGISTER_GPU
+#define REGISTER_GPU64(T) \
+ template void ConcatGPU64<T>( \
+ const GPUDevice& d, \
+ const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \
+ inputs, \
+ typename TTypes<T, 2>::Matrix* output);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU32);
+REGISTER_GPU32(bfloat16);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU64);
+REGISTER_GPU64(bfloat16);
+
+#undef REGISTER_GPU32
+#undef REGISTER_GPU64
} // end namespace tensorflow
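
The new ConcatGPU64 above mirrors ConcatGPU32 but drops the To32Bit wrappers and uses 64-bit offsets and extents. A simplified, CPU-only sketch of the same slice-and-advance pattern, assuming Eigen's unsupported Tensor module is available; the GpuDevice and To32Bit pieces are omitted:

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  using Matrix = Eigen::Tensor<float, 2>;
  Matrix a(2, 3), b(2, 2);
  a.setConstant(1.0f);
  b.setConstant(2.0f);

  // Concatenate along dimension 1 by writing each input into a slice of the
  // output and advancing the column offset, as the kernels above do.
  Matrix out(2, 5);
  Eigen::array<Eigen::Index, 2> offset{0, 0};
  for (const Matrix* in : {&a, &b}) {
    Eigen::array<Eigen::Index, 2> size{in->dimension(0), in->dimension(1)};
    out.slice(offset, size) = *in;
    offset[1] += size[1];
  }
  std::cout << out << "\n";  // columns 0-2 hold 1, columns 3-4 hold 2
  return 0;
}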
diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc
index 8122cee574..e9e18a9a37 100644
--- a/tensorflow/core/kernels/concat_op.cc
+++ b/tensorflow/core/kernels/concat_op.cc
@@ -15,6 +15,7 @@ limitations under the License.
// See docs in ../ops/array_ops.cc.
+#include <limits>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -119,7 +120,14 @@ class ConcatOp : public OpKernel {
int64 output_dim1 = output->NumElements() / inputs_flat_dim0;
auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1});
if (std::is_same<Device, GPUDevice>::value) {
- ConcatGPU<T>(c->eigen_gpu_device(), inputs_flat, &output_flat);
+ // Switching indexing to int64 might cause performance issues.
+ // Hence, we keep int32 indexing in the GPU kernel unless we need to
+ // switch to int64.
+ if (output->NumElements() < std::numeric_limits<int32>::max()) {
+        ConcatGPU32<T>(c->eigen_gpu_device(), inputs_flat, &output_flat);
+ } else {
+        ConcatGPU64<T>(c->eigen_gpu_device(), inputs_flat, &output_flat);
+ }
} else {
ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
}
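
All three updated call sites (concat_op.cc above, pack_op.cc and tensor_array_ops.cc below) pick the kernel the same way: keep 32-bit indexing for the common case and switch to 64-bit only when the output element count no longer fits in int32. A small sketch of that dispatch shape with hypothetical stand-in functions (not TensorFlow API):

#include <cstdint>
#include <iostream>
#include <limits>

// Hypothetical stand-ins for the 32-bit and 64-bit indexed kernels.
void ConcatWith32BitIndexing(int64_t n) { std::cout << "32-bit path, n = " << n << "\n"; }
void ConcatWith64BitIndexing(int64_t n) { std::cout << "64-bit path, n = " << n << "\n"; }

void Dispatch(int64_t num_elements) {
  // 32-bit indexing is kept for smaller outputs since 64-bit indexing can be
  // slower on the GPU; the 64-bit kernel is only needed past the int32 range.
  if (num_elements < std::numeric_limits<int32_t>::max()) {
    ConcatWith32BitIndexing(num_elements);
  } else {
    ConcatWith64BitIndexing(num_elements);
  }
}

int main() {
  Dispatch(1000);           // small output: 32-bit path
  Dispatch(3000000000LL);   // more than 2^31 - 1 elements: 64-bit path
  return 0;
}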
diff --git a/tensorflow/core/kernels/pack_op.cc b/tensorflow/core/kernels/pack_op.cc
index 190063f2b8..e9fc0f438e 100644
--- a/tensorflow/core/kernels/pack_op.cc
+++ b/tensorflow/core/kernels/pack_op.cc
@@ -15,6 +15,7 @@ limitations under the License.
// See docs in ../ops/array_ops.cc.
+#include <limits>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -54,7 +55,6 @@ class PackOp : public OpKernel {
values[0].shape().DebugString(), " != values[", i,
"].shape = ", values[i].shape().DebugString()));
}
-
TensorShape output_shape(values[0].shape());
output_shape.InsertDim(0, num);
@@ -70,7 +70,7 @@ class PackOp : public OpKernel {
Tensor* output;
OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
- const int output_size = output->NumElements();
+ const int64 output_size = output->NumElements();
if (output_size > 0) {
auto output_flat = output->shaped<T, 2>({1, output_size});
@@ -83,7 +83,14 @@ class PackOp : public OpKernel {
values[i].shaped<T, 2>({1, values[i].NumElements()})));
}
if (std::is_same<Device, GPUDevice>::value) {
- ConcatGPU<T>(c->eigen_gpu_device(), inputs_flat, &output_flat);
+ // Switching indexing to int64 might cause performance issues.
+ // Hence, we keep int32 indexing in the GPU kernel unless we need to
+ // switch to int64.
+ if (output_size < std::numeric_limits<int32>::max()) {
+ ConcatGPU32<T>(c->eigen_gpu_device(), inputs_flat, &output_flat);
+ } else {
+ ConcatGPU64<T>(c->eigen_gpu_device(), inputs_flat, &output_flat);
+ }
} else {
ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
}
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index 70ef00292a..55d1f514f6 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -17,7 +17,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
-#include <limits.h>
+#include <limits>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -440,7 +440,16 @@ class TensorArrayPackOp : public OpKernel {
}
if (std::is_same<Device, GPUDevice>::value) {
- ConcatGPU<T>(ctx->eigen_gpu_device(), input_tensors_flat, &output_flat);
+ // Switching indexing to int64 might cause performance issues.
+ // Hence, we keep int32 indexing in the GPU kernel unless we need to
+ // switch to int64.
+ if (output_shape.num_elements() < std::numeric_limits<int32>::max()) {
+ ConcatGPU32<T>(ctx->eigen_gpu_device(), input_tensors_flat,
+ &output_flat);
+ } else {
+ ConcatGPU64<T>(ctx->eigen_gpu_device(), input_tensors_flat,
+ &output_flat);
+ }
} else {
ConcatCPU<T>(ctx->device(), input_tensors_flat, &output_flat);
}
@@ -576,7 +585,6 @@ class TensorArrayConcatOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output_tensor));
ConstMatrixVector input_tensors_flat;
input_tensors_flat.reserve(values.size());
-
for (size_t i = 0; i < values.size(); ++i) {
const Tensor* value_t = value_tensors[i];
if (value_t->NumElements() > 0) {
@@ -589,7 +597,16 @@ class TensorArrayConcatOp : public OpKernel {
auto output_flat =
output_tensor->shaped<T, 2>({1, output_shape.num_elements()});
if (std::is_same<Device, GPUDevice>::value) {
- ConcatGPU<T>(ctx->eigen_gpu_device(), input_tensors_flat, &output_flat);
+ // Switching indexing to int64 might cause performance issues.
+ // Hence, we keep int32 indexing in the GPU kernel unless we need to
+ // switch to int64.
+ if (output_shape.num_elements() < std::numeric_limits<int32>::max()) {
+ ConcatGPU32<T>(ctx->eigen_gpu_device(), input_tensors_flat,
+ &output_flat);
+ } else {
+ ConcatGPU64<T>(ctx->eigen_gpu_device(), input_tensors_flat,
+ &output_flat);
+ }
} else {
ConcatCPU<T>(ctx->device(), input_tensors_flat, &output_flat);
}