aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.h
diff options
context:
space:
mode:
authorGravatar Max Galkin <maxgalkin@google.com>2017-10-10 20:22:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-10 20:27:12 -0700
commit00b368966c8c3e003d2a7ddf3c36165185ed0079 (patch)
treeb173ce2669d7953d9bbf0063c9433945b90b67ed /tensorflow/core/grappler/costs/op_level_cost_estimator.h
parent9885aa8636c51bdd4a155b504b7c8c22bdf22289 (diff)
Minor code cleanup in grappler cost estimation.
PiperOrigin-RevId: 171772766
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.h')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.h13
1 files changed, 8 insertions, 5 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index 0e63299bcb..3a8385dd73 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -36,11 +36,14 @@ class OpLevelCostEstimator {
virtual Costs PredictCosts(const OpContext& op_context) const;
protected:
- // Returns an estimate of device performance (in billions of operations
- // executed per second) and memory bandwidth (in GigaBytes/second) for the
- // specified device.
- virtual std::pair<double, double> GetDeviceInfo(
- const DeviceProperties& device) const;
+ // Basic device performance info, sufficient for roofline estimate.
+ struct DeviceInfo {
+ double gigaops; // Billions of operations executed per second.
+ double gb_per_sec; // Bandwidth to main memory in GB per second.
+ };
+
+ // Returns basic device performance info.
+ virtual DeviceInfo GetDeviceInfo(const DeviceProperties& device) const;
// For operations for which we haven't yet built estimates, returns a dummy
// value based on input size.