aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
diff options
context:
space:
mode:
authorGravatar Max Galkin <maxgalkin@google.com>2017-10-10 20:22:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-10 20:27:12 -0700
commit00b368966c8c3e003d2a7ddf3c36165185ed0079 (patch)
treeb173ce2669d7953d9bbf0063c9433945b90b67ed /tensorflow/core/grappler/costs/op_level_cost_estimator.cc
parent9885aa8636c51bdd4a155b504b7c8c22bdf22289 (diff)
Minor code cleanup in grappler cost estimation.
PiperOrigin-RevId: 171772766
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc27
1 files changed, 14 insertions, 13 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index b25def7612..7a1295c91e 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -292,21 +292,21 @@ Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
return costs;
}
-std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo(
+OpLevelCostEstimator::DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
const DeviceProperties& device) const {
double gflops = -1;
- double bandwidth = -1;
+ double gb_per_sec = -1;
if (device.type() == "CPU") {
// Check if vector instructions are available, and refine performance
// prediction based on this.
// Frequencies are stored in MHz in the DeviceProperties.
gflops = device.num_cores() * device.frequency() * 1e-3;
- if (bandwidth < 0) {
+ if (gb_per_sec < 0) {
if (device.bandwidth() > 0) {
- bandwidth = device.bandwidth() / 1e6;
+ gb_per_sec = device.bandwidth() / 1e6;
} else {
- bandwidth = 32;
+ gb_per_sec = 32;
}
}
} else if (device.type() == "GPU") {
@@ -328,15 +328,15 @@ std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo(
gflops = device.num_cores() * device.frequency() * 1e-3 *
cores_per_multiprocessor * kOpsPerMac;
if (device.bandwidth() > 0) {
- bandwidth = device.bandwidth() / 1e6;
+ gb_per_sec = device.bandwidth() / 1e6;
} else {
- bandwidth = 100;
+ gb_per_sec = 100;
}
}
- VLOG(1) << "Device: " << device.type() << " GFLOPS: " << gflops
- << " Bandwidth: " << bandwidth;
+ VLOG(1) << "Device: " << device.type() << " gflops: " << gflops
+ << " gb_per_sec: " << gb_per_sec;
- return std::make_pair(gflops, bandwidth);
+ return {gflops, gb_per_sec};
}
Costs OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context) const {
@@ -382,8 +382,8 @@ Costs OpLevelCostEstimator::DummyExecutionTime(
Costs OpLevelCostEstimator::PredictOpCountBasedCost(
double operations, const OpInfo& op_features) const {
- std::pair<double, double> device_perf = GetDeviceInfo(op_features.device());
- Costs::NanoSeconds compute_cost(std::ceil(operations / device_perf.first));
+ DeviceInfo device_perf = GetDeviceInfo(op_features.device());
+ Costs::NanoSeconds compute_cost(std::ceil(operations / device_perf.gigaops));
VLOG(1) << "Op:" << op_features.op() << " GOps:" << operations / 1e9
<< " Execution Time (ns):" << compute_cost.count();
@@ -394,7 +394,8 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost(
CalculateOutputSize(op_features, &found_unknown_shapes);
double total_io_size = total_input_size + total_output_size;
- Costs::NanoSeconds memory_cost(std::ceil(total_io_size / device_perf.second));
+ Costs::NanoSeconds memory_cost(
+ std::ceil(total_io_size / device_perf.gb_per_sec));
VLOG(1) << "Op:" << op_features.op() << " Size (KB):" << (total_io_size) / 1e3
<< " Memory Time (ns):" << memory_cost.count();