diff options
author | Max Galkin <maxgalkin@google.com> | 2017-10-10 20:22:50 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-10 20:27:12 -0700 |
commit | 00b368966c8c3e003d2a7ddf3c36165185ed0079 (patch) | |
tree | b173ce2669d7953d9bbf0063c9433945b90b67ed /tensorflow/core/grappler/costs/op_level_cost_estimator.cc | |
parent | 9885aa8636c51bdd4a155b504b7c8c22bdf22289 (diff) |
Minor code cleanup in grappler cost estimation.
PiperOrigin-RevId: 171772766
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 27 |
1 files changed, 14 insertions, 13 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index b25def7612..7a1295c91e 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -292,21 +292,21 @@ Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const { return costs; } -std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo( +OpLevelCostEstimator::DeviceInfo OpLevelCostEstimator::GetDeviceInfo( const DeviceProperties& device) const { double gflops = -1; - double bandwidth = -1; + double gb_per_sec = -1; if (device.type() == "CPU") { // Check if vector instructions are available, and refine performance // prediction based on this. // Frequencies are stored in MHz in the DeviceProperties. gflops = device.num_cores() * device.frequency() * 1e-3; - if (bandwidth < 0) { + if (gb_per_sec < 0) { if (device.bandwidth() > 0) { - bandwidth = device.bandwidth() / 1e6; + gb_per_sec = device.bandwidth() / 1e6; } else { - bandwidth = 32; + gb_per_sec = 32; } } } else if (device.type() == "GPU") { @@ -328,15 +328,15 @@ std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo( gflops = device.num_cores() * device.frequency() * 1e-3 * cores_per_multiprocessor * kOpsPerMac; if (device.bandwidth() > 0) { - bandwidth = device.bandwidth() / 1e6; + gb_per_sec = device.bandwidth() / 1e6; } else { - bandwidth = 100; + gb_per_sec = 100; } } - VLOG(1) << "Device: " << device.type() << " GFLOPS: " << gflops - << " Bandwidth: " << bandwidth; + VLOG(1) << "Device: " << device.type() << " gflops: " << gflops + << " gb_per_sec: " << gb_per_sec; - return std::make_pair(gflops, bandwidth); + return {gflops, gb_per_sec}; } Costs OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context) const { @@ -382,8 +382,8 @@ Costs OpLevelCostEstimator::DummyExecutionTime( Costs OpLevelCostEstimator::PredictOpCountBasedCost( double operations, const OpInfo& op_features) const { - std::pair<double, double> device_perf = GetDeviceInfo(op_features.device()); - Costs::NanoSeconds compute_cost(std::ceil(operations / device_perf.first)); + DeviceInfo device_perf = GetDeviceInfo(op_features.device()); + Costs::NanoSeconds compute_cost(std::ceil(operations / device_perf.gigaops)); VLOG(1) << "Op:" << op_features.op() << " GOps:" << operations / 1e9 << " Execution Time (ns):" << compute_cost.count(); @@ -394,7 +394,8 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost( CalculateOutputSize(op_features, &found_unknown_shapes); double total_io_size = total_input_size + total_output_size; - Costs::NanoSeconds memory_cost(std::ceil(total_io_size / device_perf.second)); + Costs::NanoSeconds memory_cost( + std::ceil(total_io_size / device_perf.gb_per_sec)); VLOG(1) << "Op:" << op_features.op() << " Size (KB):" << (total_io_size) / 1e3 << " Memory Time (ns):" << memory_cost.count(); |