aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/grappler/costs/analytical_cost_estimator.cc13
-rw-r--r--tensorflow/core/grappler/costs/analytical_cost_estimator.h9
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.h2
3 files changed, 19 insertions, 5 deletions
diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
index 8ecf7be854..7a1e7fcace 100644
--- a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
@@ -32,7 +32,16 @@ namespace grappler {
AnalyticalCostEstimator::AnalyticalCostEstimator(Cluster* cluster,
bool use_static_shapes)
- : cluster_(cluster), use_static_shapes_(use_static_shapes) {}
+ : cluster_(cluster),
+ node_estimator_(new OpLevelCostEstimator()),
+ use_static_shapes_(use_static_shapes) {}
+
+AnalyticalCostEstimator::AnalyticalCostEstimator(
+ Cluster* cluster, OpLevelCostEstimator* node_estimator,
+ bool use_static_shapes)
+ : cluster_(cluster),
+ node_estimator_(node_estimator),
+ use_static_shapes_(use_static_shapes) {}
Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) {
item_ = item;
@@ -68,7 +77,7 @@ Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph,
auto& op_info = node_info.op_info;
const string& op_name = node_info.name;
- node_costs = node_estimator_.PredictCosts(op_info);
+ node_costs = node_estimator_->PredictCosts(op_info);
if (node_costs.inaccurate) {
inaccurate_nodes.push_back(op_name);
}
diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.h b/tensorflow/core/grappler/costs/analytical_cost_estimator.h
index 03e7faa4ff..ef186fc021 100644
--- a/tensorflow/core/grappler/costs/analytical_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.h
@@ -37,7 +37,12 @@ struct GrapplerItem;
class AnalyticalCostEstimator : public CostEstimator {
public:
// Does not take ownership of cluster.
- explicit AnalyticalCostEstimator(Cluster* cluster, bool use_static_shapes);
+ AnalyticalCostEstimator(Cluster* cluster, bool use_static_shapes);
+ // Does not take ownership of the cluster, but takes ownership of the
+ // node_estimator
+ AnalyticalCostEstimator(Cluster* cluster,
+ OpLevelCostEstimator* node_estimator,
+ bool use_static_shapes);
~AnalyticalCostEstimator() override {}
// Initalizes the estimator for the specified grappler item.
@@ -53,7 +58,7 @@ class AnalyticalCostEstimator : public CostEstimator {
private:
Cluster* cluster_; // Not owned.
GrapplerItem item_;
- OpLevelCostEstimator node_estimator_;
+ std::unique_ptr<OpLevelCostEstimator> node_estimator_;
bool use_static_shapes_;
};
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index 266b633922..5e0a97653c 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -32,7 +32,7 @@ class OpLevelCostEstimator {
OpLevelCostEstimator();
virtual ~OpLevelCostEstimator() {}
- Costs PredictCosts(const OpInfo& op_features) const;
+ virtual Costs PredictCosts(const OpInfo& op_features) const;
protected:
// Returns an estimate of device performance (in billions of operations