diff options
3 files changed, 19 insertions, 5 deletions
diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc index 8ecf7be854..7a1e7fcace 100644 --- a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc @@ -32,7 +32,16 @@ namespace grappler { AnalyticalCostEstimator::AnalyticalCostEstimator(Cluster* cluster, bool use_static_shapes) - : cluster_(cluster), use_static_shapes_(use_static_shapes) {} + : cluster_(cluster), + node_estimator_(new OpLevelCostEstimator()), + use_static_shapes_(use_static_shapes) {} + +AnalyticalCostEstimator::AnalyticalCostEstimator( + Cluster* cluster, OpLevelCostEstimator* node_estimator, + bool use_static_shapes) + : cluster_(cluster), + node_estimator_(node_estimator), + use_static_shapes_(use_static_shapes) {} Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) { item_ = item; @@ -68,7 +77,7 @@ Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph, auto& op_info = node_info.op_info; const string& op_name = node_info.name; - node_costs = node_estimator_.PredictCosts(op_info); + node_costs = node_estimator_->PredictCosts(op_info); if (node_costs.inaccurate) { inaccurate_nodes.push_back(op_name); } diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.h b/tensorflow/core/grappler/costs/analytical_cost_estimator.h index 03e7faa4ff..ef186fc021 100644 --- a/tensorflow/core/grappler/costs/analytical_cost_estimator.h +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.h @@ -37,7 +37,12 @@ struct GrapplerItem; class AnalyticalCostEstimator : public CostEstimator { public: // Does not take ownership of cluster. - explicit AnalyticalCostEstimator(Cluster* cluster, bool use_static_shapes); + AnalyticalCostEstimator(Cluster* cluster, bool use_static_shapes); + // Does not take ownership of the cluster, but takes ownership of the + // node_estimator + AnalyticalCostEstimator(Cluster* cluster, + OpLevelCostEstimator* node_estimator, + bool use_static_shapes); ~AnalyticalCostEstimator() override {} // Initalizes the estimator for the specified grappler item. @@ -53,7 +58,7 @@ class AnalyticalCostEstimator : public CostEstimator { private: Cluster* cluster_; // Not owned. GrapplerItem item_; - OpLevelCostEstimator node_estimator_; + std::unique_ptr<OpLevelCostEstimator> node_estimator_; bool use_static_shapes_; }; diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index 266b633922..5e0a97653c 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -32,7 +32,7 @@ class OpLevelCostEstimator { OpLevelCostEstimator(); virtual ~OpLevelCostEstimator() {} - Costs PredictCosts(const OpInfo& op_features) const; + virtual Costs PredictCosts(const OpInfo& op_features) const; protected: // Returns an estimate of device performance (in billions of operations |