aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-26 22:55:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-26 22:58:44 -0700
commit40dee372e3ee844c4746baa914c07b9c582a2ce7 (patch)
treebd39a01c0aad8a6cfc8e5d4205674b5a8892133d /tensorflow/core/grappler/costs/op_level_cost_estimator.h
parent680c2f5d988fb1f3b725fb8f0a67d1926be8169b (diff)
Define OpContext and use it for OpLevelCostEstimator.
This CL does not add any functionality (except GraphDef's function library pointer is passed to OpContext), but we can later add additional fields to OpContext struct for extending VirtualCluster, Scheduler, Placer, and others. PiperOrigin-RevId: 170157235
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.h')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.h23
1 files changed, 12 insertions, 11 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index b4302dc9e1..0e63299bcb 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include "tensorflow/core/grappler/costs/cost_estimator.h"
+#include "tensorflow/core/grappler/costs/op_context.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/util/padding.h"
@@ -32,7 +33,7 @@ class OpLevelCostEstimator {
OpLevelCostEstimator();
virtual ~OpLevelCostEstimator() {}
- virtual Costs PredictCosts(const OpInfo& op_features) const;
+ virtual Costs PredictCosts(const OpContext& op_context) const;
protected:
// Returns an estimate of device performance (in billions of operations
@@ -43,7 +44,7 @@ class OpLevelCostEstimator {
// For operations for which we haven't yet built estimates, returns a dummy
// value based on input size.
- Costs DummyExecutionTime(const OpInfo& op_features) const;
+ Costs DummyExecutionTime(const OpContext& op_context) const;
// Naive cost estimate based on operations divided by device ops/sec.
Costs PredictOpCountBasedCost(double operations,
@@ -122,14 +123,14 @@ class OpLevelCostEstimator {
// Implementation of costs other than
// execution_time is optional, depending on the
// device.
- Costs PredictConv2D(const OpInfo& op_features) const;
- Costs PredictCwiseOp(const OpInfo& op_features) const;
- Costs PredictConv2DBackpropInput(const OpInfo& op_features) const;
- Costs PredictConv2DBackpropFilter(const OpInfo& op_features) const;
- Costs PredictMatMul(const OpInfo& op_features) const;
- Costs PredictNoOp(const OpInfo& op_features) const;
- Costs PredictBatchMatMul(const OpInfo& op_features) const;
- Costs PredictMetadata(const OpInfo& op_features) const;
+ Costs PredictConv2D(const OpContext& op_context) const;
+ Costs PredictCwiseOp(const OpContext& op_context) const;
+ Costs PredictConv2DBackpropInput(const OpContext& op_context) const;
+ Costs PredictConv2DBackpropFilter(const OpContext& op_context) const;
+ Costs PredictMatMul(const OpContext& op_context) const;
+ Costs PredictNoOp(const OpContext& op_context) const;
+ Costs PredictBatchMatMul(const OpContext& op_context) const;
+ Costs PredictMetadata(const OpContext& op_context) const;
// Utility function for safe division. Returns 0
// if rhs is 0 or negative.
@@ -148,7 +149,7 @@ class OpLevelCostEstimator {
protected:
std::map<string, int> elementwise_ops_;
- typedef std::function<Costs(const OpInfo& op_feature)> CostImpl;
+ typedef std::function<Costs(const OpContext& op_context)> CostImpl;
std::map<string, CostImpl> device_cost_impl_;
// If true, assume compute and memory overlap; hence, the op cost is max of
// compute_time and memory_time, insteaf of sum of those two.