diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-10-18 12:22:04 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-18 12:25:47 -0700 |
commit | 192f1c24ec6692342391c03bb620f5de1af9de3b (patch) | |
tree | 5fca708751bf3ec66db86317115eee7d04a52ef7 | |
parent | f1603b7893f922dfe64244c6bae9b93d7d594437 (diff) |
Fixed work size computation in Split and SplitV ops to avoid integer overflow.
PiperOrigin-RevId: 172637818
-rw-r--r-- | tensorflow/core/kernels/split_op.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/kernels/split_v_op.cc | 8 |
2 files changed, 8 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc
index 4d2100c59c..58e1a73be6 100644
--- a/tensorflow/core/kernels/split_op.cc
+++ b/tensorflow/core/kernels/split_op.cc
@@ -167,11 +167,11 @@ class SplitOpCPU : public SplitOpBase<CPUDevice, T> {
     const auto num_threads =
         context->device()->tensorflow_cpu_worker_threads()->num_threads;
     // TODO(jewillco): Tune heuristic further.
+    const auto input_element_count = input_shape.num_elements();
     const bool use_parallelism_between_outputs =
         (num_split >= 4 &&
-         input_shape.num_elements() >=
-             std::max(num_threads, num_split) * 4096 &&
-         input_shape.num_elements() < num_split * 180 * 1024);
+         input_element_count >= std::max(num_threads, num_split) * 4096 &&
+         input_element_count < num_split * 180 * 1024);
     auto range_output_func = [&indices, context, &output_shape, prefix_dim_size,
                              split_dim_output_size, suffix_dim_size, &sizes,
@@ -209,7 +209,7 @@ class SplitOpCPU : public SplitOpBase<CPUDevice, T> {
       // Run in parallel, disabling parallelism in functor.
       Shard(num_split,
             context->device()->tensorflow_cpu_worker_threads()->workers,
-            num_split, kint64max, range_output_func);
+            num_split, input_element_count / num_split, range_output_func);
     } else {
       // Run sequentially, but allow internal parallelism in functor.
       range_output_func(0, num_split);
diff --git a/tensorflow/core/kernels/split_v_op.cc b/tensorflow/core/kernels/split_v_op.cc
index e2dd66da1e..3316e5fcc9 100644
--- a/tensorflow/core/kernels/split_v_op.cc
+++ b/tensorflow/core/kernels/split_v_op.cc
@@ -225,11 +225,11 @@ class SplitVOpCPU : public SplitVOpBase<CPUDevice, T, Tlen> {
     const auto num_threads =
         context->device()->tensorflow_cpu_worker_threads()->num_threads;
     // TODO(jewillco): Tune heuristic further.
+    const auto input_element_count = input_shape.num_elements();
     const bool use_parallelism_between_outputs =
         (num_split >= 4 &&
-         input_shape.num_elements() >=
-             std::max(num_threads, num_split) * 4096 &&
-         input_shape.num_elements() < num_split * 180 * 1024);
+         input_element_count >= std::max(num_threads, num_split) * 4096 &&
+         input_element_count < num_split * 180 * 1024);
     auto range_output_func = [&indices, context, &input_shape, prefix_dim_size,
                              split_dim, &split_sizes_vec, &split_start_points,
@@ -267,7 +267,7 @@ class SplitVOpCPU : public SplitVOpBase<CPUDevice, T, Tlen> {
       // Run in parallel, disabling parallelism in functor.
       Shard(num_split,
             context->device()->tensorflow_cpu_worker_threads()->workers,
-            num_split, kint64max, range_output_func);
+            num_split, input_element_count / num_split, range_output_func);
     } else {
       // Run sequentially, but allow internal parallelism in functor.
       range_output_func(0, num_split);