about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-18 12:22:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-18 12:25:47 -0700
commit192f1c24ec6692342391c03bb620f5de1af9de3b (patch)
tree5fca708751bf3ec66db86317115eee7d04a52ef7
parentf1603b7893f922dfe64244c6bae9b93d7d594437 (diff)
Fixed work size computation in Split and SplitV ops to avoid integer overflow.
PiperOrigin-RevId: 172637818
-rw-r--r-- tensorflow/core/kernels/split_op.cc 8
-rw-r--r-- tensorflow/core/kernels/split_v_op.cc 8
2 files changed, 8 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc
index 4d2100c59c..58e1a73be6 100644
--- a/tensorflow/core/kernels/split_op.cc
+++ b/tensorflow/core/kernels/split_op.cc
@@ -167,11 +167,11 @@ class SplitOpCPU : public SplitOpBase<CPUDevice, T> {
const auto num_threads =
context->device()->tensorflow_cpu_worker_threads()->num_threads;
// TODO(jewillco): Tune heuristic further.
+ const auto input_element_count = input_shape.num_elements();
const bool use_parallelism_between_outputs =
(num_split >= 4 &&
- input_shape.num_elements() >=
- std::max(num_threads, num_split) * 4096 &&
- input_shape.num_elements() < num_split * 180 * 1024);
+ input_element_count >= std::max(num_threads, num_split) * 4096 &&
+ input_element_count < num_split * 180 * 1024);
auto range_output_func = [&indices, context, &output_shape, prefix_dim_size,
split_dim_output_size, suffix_dim_size, &sizes,
@@ -209,7 +209,7 @@ class SplitOpCPU : public SplitOpBase<CPUDevice, T> {
// Run in parallel, disabling parallelism in functor.
Shard(num_split,
context->device()->tensorflow_cpu_worker_threads()->workers,
- num_split, kint64max, range_output_func);
+ num_split, input_element_count / num_split, range_output_func);
} else {
// Run sequentially, but allow internal parallelism in functor.
range_output_func(0, num_split);
diff --git a/tensorflow/core/kernels/split_v_op.cc b/tensorflow/core/kernels/split_v_op.cc
index e2dd66da1e..3316e5fcc9 100644
--- a/tensorflow/core/kernels/split_v_op.cc
+++ b/tensorflow/core/kernels/split_v_op.cc
@@ -225,11 +225,11 @@ class SplitVOpCPU : public SplitVOpBase<CPUDevice, T, Tlen> {
const auto num_threads =
context->device()->tensorflow_cpu_worker_threads()->num_threads;
// TODO(jewillco): Tune heuristic further.
+ const auto input_element_count = input_shape.num_elements();
const bool use_parallelism_between_outputs =
(num_split >= 4 &&
- input_shape.num_elements() >=
- std::max(num_threads, num_split) * 4096 &&
- input_shape.num_elements() < num_split * 180 * 1024);
+ input_element_count >= std::max(num_threads, num_split) * 4096 &&
+ input_element_count < num_split * 180 * 1024);
auto range_output_func = [&indices, context, &input_shape, prefix_dim_size,
split_dim, &split_sizes_vec, &split_start_points,
@@ -267,7 +267,7 @@ class SplitVOpCPU : public SplitVOpBase<CPUDevice, T, Tlen> {
// Run in parallel, disabling parallelism in functor.
Shard(num_split,
context->device()->tensorflow_cpu_worker_threads()->workers,
- num_split, kint64max, range_output_func);
+ num_split, input_element_count / num_split, range_output_func);
} else {
// Run sequentially, but allow internal parallelism in functor.
range_output_func(0, num_split);