From 00b368966c8c3e003d2a7ddf3c36165185ed0079 Mon Sep 17 00:00:00 2001 From: Max Galkin Date: Tue, 10 Oct 2017 20:22:50 -0700 Subject: Minor code cleanup in grappler cost estimation. PiperOrigin-RevId: 171772766 --- .../core/grappler/costs/op_level_cost_estimator.cc | 27 +++++++++++----------- .../core/grappler/costs/op_level_cost_estimator.h | 13 +++++++---- 2 files changed, 22 insertions(+), 18 deletions(-) (limited to 'tensorflow/core') diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index b25def7612..7a1295c91e 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -292,21 +292,21 @@ Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const { return costs; } -std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo( +OpLevelCostEstimator::DeviceInfo OpLevelCostEstimator::GetDeviceInfo( const DeviceProperties& device) const { double gflops = -1; - double bandwidth = -1; + double gb_per_sec = -1; if (device.type() == "CPU") { // Check if vector instructions are available, and refine performance // prediction based on this. // Frequencies are stored in MHz in the DeviceProperties. 
gflops = device.num_cores() * device.frequency() * 1e-3; - if (bandwidth < 0) { + if (gb_per_sec < 0) { if (device.bandwidth() > 0) { - bandwidth = device.bandwidth() / 1e6; + gb_per_sec = device.bandwidth() / 1e6; } else { - bandwidth = 32; + gb_per_sec = 32; } } } else if (device.type() == "GPU") { @@ -328,15 +328,15 @@ std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo( gflops = device.num_cores() * device.frequency() * 1e-3 * cores_per_multiprocessor * kOpsPerMac; if (device.bandwidth() > 0) { - bandwidth = device.bandwidth() / 1e6; + gb_per_sec = device.bandwidth() / 1e6; } else { - bandwidth = 100; + gb_per_sec = 100; } } - VLOG(1) << "Device: " << device.type() << " GFLOPS: " << gflops - << " Bandwidth: " << bandwidth; + VLOG(1) << "Device: " << device.type() << " gflops: " << gflops + << " gb_per_sec: " << gb_per_sec; - return std::make_pair(gflops, bandwidth); + return {gflops, gb_per_sec}; } Costs OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context) const { @@ -382,8 +382,8 @@ Costs OpLevelCostEstimator::DummyExecutionTime( Costs OpLevelCostEstimator::PredictOpCountBasedCost( double operations, const OpInfo& op_features) const { - std::pair<double, double> device_perf = GetDeviceInfo(op_features.device()); - Costs::NanoSeconds compute_cost(std::ceil(operations / device_perf.first)); + DeviceInfo device_perf = GetDeviceInfo(op_features.device()); + Costs::NanoSeconds compute_cost(std::ceil(operations / device_perf.gigaops)); VLOG(1) << "Op:" << op_features.op() << " GOps:" << operations / 1e9 << " Execution Time (ns):" << compute_cost.count(); @@ -394,7 +394,8 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost( CalculateOutputSize(op_features, &found_unknown_shapes); double total_io_size = total_input_size + total_output_size; - Costs::NanoSeconds memory_cost(std::ceil(total_io_size / device_perf.second)); + Costs::NanoSeconds memory_cost( + std::ceil(total_io_size / device_perf.gb_per_sec)); VLOG(1) << "Op:" << op_features.op() << " Size (KB):" << 
(total_io_size) / 1e3 << " Memory Time (ns):" << memory_cost.count(); diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index 0e63299bcb..3a8385dd73 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -36,11 +36,14 @@ class OpLevelCostEstimator { virtual Costs PredictCosts(const OpContext& op_context) const; protected: - // Returns an estimate of device performance (in billions of operations - // executed per second) and memory bandwidth (in GigaBytes/second) for the - // specified device. - virtual std::pair<double, double> GetDeviceInfo( - const DeviceProperties& device) const; + // Basic device performance info, sufficient for roofline estimate. + struct DeviceInfo { + double gigaops; // Billions of operations executed per second. + double gb_per_sec; // Bandwidth to main memory in GB per second. + }; + + // Returns basic device performance info. + virtual DeviceInfo GetDeviceInfo(const DeviceProperties& device) const; // For operations for which we haven't yet built estimates, returns a dummy // value based on input size. -- cgit v1.2.3