aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.h')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.h23
1 files changed, 12 insertions, 11 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index b4302dc9e1..0e63299bcb 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include "tensorflow/core/grappler/costs/cost_estimator.h"
+#include "tensorflow/core/grappler/costs/op_context.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/util/padding.h"
@@ -32,7 +33,7 @@ class OpLevelCostEstimator {
OpLevelCostEstimator();
virtual ~OpLevelCostEstimator() {}
- virtual Costs PredictCosts(const OpInfo& op_features) const;
+ virtual Costs PredictCosts(const OpContext& op_context) const;
protected:
// Returns an estimate of device performance (in billions of operations
@@ -43,7 +44,7 @@ class OpLevelCostEstimator {
// For operations for which we haven't yet built estimates, returns a dummy
// value based on input size.
- Costs DummyExecutionTime(const OpInfo& op_features) const;
+ Costs DummyExecutionTime(const OpContext& op_context) const;
// Naive cost estimate based on operations divided by device ops/sec.
Costs PredictOpCountBasedCost(double operations,
@@ -122,14 +123,14 @@ class OpLevelCostEstimator {
// Implementation of costs other than
// execution_time is optional, depending on the
// device.
- Costs PredictConv2D(const OpInfo& op_features) const;
- Costs PredictCwiseOp(const OpInfo& op_features) const;
- Costs PredictConv2DBackpropInput(const OpInfo& op_features) const;
- Costs PredictConv2DBackpropFilter(const OpInfo& op_features) const;
- Costs PredictMatMul(const OpInfo& op_features) const;
- Costs PredictNoOp(const OpInfo& op_features) const;
- Costs PredictBatchMatMul(const OpInfo& op_features) const;
- Costs PredictMetadata(const OpInfo& op_features) const;
+ Costs PredictConv2D(const OpContext& op_context) const;
+ Costs PredictCwiseOp(const OpContext& op_context) const;
+ Costs PredictConv2DBackpropInput(const OpContext& op_context) const;
+ Costs PredictConv2DBackpropFilter(const OpContext& op_context) const;
+ Costs PredictMatMul(const OpContext& op_context) const;
+ Costs PredictNoOp(const OpContext& op_context) const;
+ Costs PredictBatchMatMul(const OpContext& op_context) const;
+ Costs PredictMetadata(const OpContext& op_context) const;
// Utility function for safe division. Returns 0
// if rhs is 0 or negative.
@@ -148,7 +149,7 @@ class OpLevelCostEstimator {
protected:
std::map<string, int> elementwise_ops_;
- typedef std::function<Costs(const OpInfo& op_feature)> CostImpl;
+ typedef std::function<Costs(const OpContext& op_context)> CostImpl;
std::map<string, CostImpl> device_cost_impl_;
// If true, assume compute and memory overlap; hence, the op cost is max of
// compute_time and memory_time, insteaf of sum of those two.