diff options
author | 2018-03-20 11:45:23 -0700 | |
---|---|---|
committer | 2018-03-20 11:48:34 -0700 | |
commit | 15d6e8310e1f2ffaa901110903ce7403717b4d2b (patch) | |
tree | f7653ce34fc0e4fa36a6554cf1ebe7b4c57cc122 /tensorflow/core/grappler/costs/op_level_cost_estimator.h | |
parent | f57f7d09eeb7402f2455564fafbcebf7ac4b8fe3 (diff) |
Improved accuracy of op_level_cost_estimator (QuantizeV2, Dequantize, Gather).
PiperOrigin-RevId: 189779691
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.h')
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.h | 16 |
1 files changed, 11 insertions, 5 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index 7bb530fe31..e5dd31a7a2 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -51,10 +51,15 @@ class OpLevelCostEstimator { // Predict cost of an op for which no accurate estimator is defined. Costs PredictCostOfAnUnknownOp(const OpContext& op_context) const; - // Naive cost estimate based on operations divided by device ops/sec, - // and input/output tensor sizes. - Costs PredictOpCountBasedCost(double operations, - const OpInfo& op_features) const; + // Naive cost estimate based on the given operations count and total + // input/output tensor sizes of the given op_info combined. + Costs PredictOpCountBasedCost(double operations, const OpInfo& op_info) const; + + // Naive cost estimate based on the given operations count and the given total + // io size in bytes. Sizes of op_info inputs and outputs are not taken into + // consideration. + Costs PredictOpCountBasedCost(double operations, double total_io_bytes, + const OpInfo& op_info) const; // This family of routines counts the number of operations to perform the // specified TensorFlow Op. @@ -125,7 +130,7 @@ class OpLevelCostEstimator { // implementation just divides the operations to // perform the op (from the "Count" routines, // above) by the device peak operations per - // second. Override to supply a better estimate. + // second. // Implementation of costs other than // execution_time is optional, depending on the // device. @@ -139,6 +144,7 @@ class OpLevelCostEstimator { Costs PredictVariable(const OpContext& op_context) const; Costs PredictBatchMatMul(const OpContext& op_context) const; Costs PredictMetadata(const OpContext& op_context) const; + Costs PredictGather(const OpContext& op_context) const; // Utility function for safe division. Returns 0 // if rhs is 0 or negative. |