diff options
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.h')
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.h | 23 |
1 files changed, 12 insertions, 11 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index b4302dc9e1..0e63299bcb 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -21,6 +21,7 @@ limitations under the License. #include <string> #include "tensorflow/core/grappler/costs/cost_estimator.h" +#include "tensorflow/core/grappler/costs/op_context.h" #include "tensorflow/core/grappler/costs/op_performance_data.pb.h" #include "tensorflow/core/util/padding.h" @@ -32,7 +33,7 @@ class OpLevelCostEstimator { OpLevelCostEstimator(); virtual ~OpLevelCostEstimator() {} - virtual Costs PredictCosts(const OpInfo& op_features) const; + virtual Costs PredictCosts(const OpContext& op_context) const; protected: // Returns an estimate of device performance (in billions of operations @@ -43,7 +44,7 @@ class OpLevelCostEstimator { // For operations for which we haven't yet built estimates, returns a dummy // value based on input size. - Costs DummyExecutionTime(const OpInfo& op_features) const; + Costs DummyExecutionTime(const OpContext& op_context) const; // Naive cost estimate based on operations divided by device ops/sec. Costs PredictOpCountBasedCost(double operations, @@ -122,14 +123,14 @@ class OpLevelCostEstimator { // Implementation of costs other than // execution_time is optional, depending on the // device. - Costs PredictConv2D(const OpInfo& op_features) const; - Costs PredictCwiseOp(const OpInfo& op_features) const; - Costs PredictConv2DBackpropInput(const OpInfo& op_features) const; - Costs PredictConv2DBackpropFilter(const OpInfo& op_features) const; - Costs PredictMatMul(const OpInfo& op_features) const; - Costs PredictNoOp(const OpInfo& op_features) const; - Costs PredictBatchMatMul(const OpInfo& op_features) const; - Costs PredictMetadata(const OpInfo& op_features) const; + Costs PredictConv2D(const OpContext& op_context) const; + Costs PredictCwiseOp(const OpContext& op_context) const; + Costs PredictConv2DBackpropInput(const OpContext& op_context) const; + Costs PredictConv2DBackpropFilter(const OpContext& op_context) const; + Costs PredictMatMul(const OpContext& op_context) const; + Costs PredictNoOp(const OpContext& op_context) const; + Costs PredictBatchMatMul(const OpContext& op_context) const; + Costs PredictMetadata(const OpContext& op_context) const; // Utility function for safe division. Returns 0 // if rhs is 0 or negative. @@ -148,7 +149,7 @@ class OpLevelCostEstimator { protected: std::map<string, int> elementwise_ops_; - typedef std::function<Costs(const OpInfo& op_feature)> CostImpl; + typedef std::function<Costs(const OpContext& op_context)> CostImpl; std::map<string, CostImpl> device_cost_impl_; // If true, assume compute and memory overlap; hence, the op cost is max of // compute_time and memory_time, insteaf of sum of those two. |