author     A. Unique TensorFlower <nobody@tensorflow.org>  2016-05-24 13:14:03 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2016-05-24 14:39:40 -0700
commit     36694792612675ae16206bc3f4dcf5fec328085f (patch)
tree       24e74cb3f12b990d74571c92beb42266b6099aa4 /tensorflow/core/kernels/random_op.cc
parent     67ddfa5b34867691cc68708930c75c806e4750e2 (diff)
Remove heuristic caps on parallelism that should now be handled by the cost model.
Adjust the cost model for FloatToBFloat16 and BFloat16ToFloat; they do not take 100 cycles per element. This CL is a companion to cl/122779011, which makes the caps effective again even with the non-blocking threadpool. Change: 123144919
Diffstat (limited to 'tensorflow/core/kernels/random_op.cc')
-rw-r--r--  tensorflow/core/kernels/random_op.cc  12
1 file changed, 6 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
index bd70663eb3..670a041e18 100644
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ -161,13 +161,11 @@ struct FillPhiloxRandom<CPUDevice, Distribution> {
int64 total_group_count = (size + kGroupSize - 1) / kGroupSize;
- // Limit to maximum six threads for now. The performance scaling is very
- // sub-linear. Too many threads causes a much worse overall performance.
- int num_workers = 6;
const int kGroupCost =
random::PhiloxRandom::kResultElementCount *
(random::PhiloxRandom::kElementCost + Distribution::kElementCost);
- Shard(num_workers, worker_threads.workers, total_group_count, kGroupCost,
+ Shard(worker_threads.num_threads, worker_threads.workers, total_group_count,
+ kGroupCost,
[&gen, data, size, dist](int64 start_group, int64 limit_group) {
FillPhiloxRandomTask<
Distribution,
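
Note on the hunk above: Shard (declared in tensorflow/core/util/work_sharder.h) already derives the useful degree of parallelism from the per-unit cost it is given, so passing worker_threads.num_threads lets the cost model do the capping that the removed "num_workers = 6" used to hard-code. The following is a minimal, self-contained sketch of that idea only; NaiveShard and kMinCostPerShard are hypothetical names, and this is not the real work_sharder implementation.

// A simplified, hypothetical illustration of cost-model-driven sharding.
// It only mimics the idea that the worker count is derived from the
// estimated per-unit cost rather than from a hard-coded cap.
#include <algorithm>
#include <cstdint>
#include <functional>
#include <thread>
#include <vector>

void NaiveShard(int max_parallelism, int64_t total, int64_t cost_per_unit,
                const std::function<void(int64_t, int64_t)>& work) {
  // Assume a chunk is only worth dispatching if it costs ~100k cycles
  // (an assumed threshold, not taken from the TensorFlow sources).
  constexpr int64_t kMinCostPerShard = 100000;
  const int64_t useful_shards =
      std::max<int64_t>(1, total * cost_per_unit / kMinCostPerShard);
  const int num_shards = static_cast<int>(
      std::min<int64_t>(max_parallelism, useful_shards));
  const int64_t block = (total + num_shards - 1) / num_shards;

  std::vector<std::thread> threads;
  for (int64_t start = 0; start < total; start += block) {
    const int64_t limit = std::min(start + block, total);
    // Same [start, limit) contract as the lambda passed to Shard above.
    threads.emplace_back(work, start, limit);
  }
  for (auto& t : threads) t.join();
}

In the real call above, the cost argument is kGroupCost, i.e. kResultElementCount * (kElementCost + Distribution::kElementCost) cycles per group of generated samples.
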
@@ -399,8 +397,10 @@ class MultinomialOp : public OpKernel {
sizeof(int64) * num_samples);
}
};
- Shard(std::min(batch_size, worker_threads.num_threads),
- worker_threads.workers, batch_size, num_samples * num_classes * 2,
+ // Rough estimate, log2() takes from 58-680 cycles on Haswell.
+ // The functor here calls log twice for each element.
+ const int64 cost = 500 * num_samples * num_classes;
+ Shard(worker_threads.num_threads, worker_threads.workers, batch_size, cost,
DoWork);
}
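
For the MultinomialOp hunk, the new cost constant can be sanity-checked by hand: the added comment quotes log2() at roughly 58-680 cycles on Haswell, and the functor calls log twice per (sample, class) element, so 500 cycles per element corresponds to an assumed mid-range ~250 cycles per call. A tiny hypothetical check of that arithmetic follows; the shapes are example values, not taken from the diff.

// Hypothetical back-of-the-envelope check of the cost passed to Shard in
// MultinomialOp; the cycle numbers are rough estimates, not measurements.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t batch_size = 32;     // example values only
  const int64_t num_samples = 1024;
  const int64_t num_classes = 10000;

  // Two log() calls per element at an assumed ~250 cycles each.
  const int64_t cycles_per_element = 2 * 250;
  const int64_t cost = cycles_per_element * num_samples * num_classes;

  // In the diff, this per-batch-element cost is handed to Shard together
  // with batch_size as the total number of work units.
  std::printf("batch_size=%lld, cost per unit ~= %lld cycles\n",
              static_cast<long long>(batch_size),
              static_cast<long long>(cost));
  return 0;
}
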