about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.h
diff options
context:
space:
mode:
authorGravatar Max Galkin <maxgalkin@google.com>2018-03-20 11:45:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-20 11:48:34 -0700
commit 15d6e8310e1f2ffaa901110903ce7403717b4d2b (patch)
tree f7653ce34fc0e4fa36a6554cf1ebe7b4c57cc122 /tensorflow/core/grappler/costs/op_level_cost_estimator.h
parent f57f7d09eeb7402f2455564fafbcebf7ac4b8fe3 (diff)
Improved accuracy of op_level_cost_estimator (QuantizeV2, Dequantize, Gather).
PiperOrigin-RevId: 189779691
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.h')
-rw-r--r-- tensorflow/core/grappler/costs/op_level_cost_estimator.h 16
1 file changed, 11 insertions, 5 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index 7bb530fe31..e5dd31a7a2 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -51,10 +51,15 @@ class OpLevelCostEstimator {
// Predict cost of an op for which no accurate estimator is defined.
Costs PredictCostOfAnUnknownOp(const OpContext& op_context) const;
- // Naive cost estimate based on operations divided by device ops/sec,
- // and input/output tensor sizes.
- Costs PredictOpCountBasedCost(double operations,
- const OpInfo& op_features) const;
+ // Naive cost estimate based on the given operations count and total
+ // input/output tensor sizes of the given op_info combined.
+ Costs PredictOpCountBasedCost(double operations, const OpInfo& op_info) const;
+
+ // Naive cost estimate based on the given operations count and the given total
+ // io size in bytes. Sizes of op_info inputs and outputs are not taken into
+ // consideration.
+ Costs PredictOpCountBasedCost(double operations, double total_io_bytes,
+ const OpInfo& op_info) const;
// This family of routines counts the number of operations to perform the
// specified TensorFlow Op.
@@ -125,7 +130,7 @@ class OpLevelCostEstimator {
// implementation just divides the operations to
// perform the op (from the "Count" routines,
// above) by the device peak operations per
- // second. Override to supply a better estimate.
+ // second.
// Implementation of costs other than
// execution_time is optional, depending on the
// device.
@@ -139,6 +144,7 @@ class OpLevelCostEstimator {
Costs PredictVariable(const OpContext& op_context) const;
Costs PredictBatchMatMul(const OpContext& op_context) const;
Costs PredictMetadata(const OpContext& op_context) const;
+ Costs PredictGather(const OpContext& op_context) const;
// Utility function for safe division. Returns 0
// if rhs is 0 or negative.