aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2017-09-26 22:55:11 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-09-26 22:58:44 -0700
commit: 40dee372e3ee844c4746baa914c07b9c582a2ce7 (patch)
tree: bd39a01c0aad8a6cfc8e5d4205674b5a8892133d /tensorflow/core/grappler/costs/op_level_cost_estimator.cc
parent: 680c2f5d988fb1f3b725fb8f0a67d1926be8169b (diff)
Define OpContext and use it for OpLevelCostEstimator.
This CL does not add any functionality (except that GraphDef's function library pointer is now passed to OpContext), but we can later add additional fields to the OpContext struct for extending VirtualCluster, Scheduler, Placer, and others.
PiperOrigin-RevId: 170157235
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r-- tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 47
1 files changed, 29 insertions, 18 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index fbafed7c1f..b25def7612 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -142,10 +142,12 @@ int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1,
OpLevelCostEstimator::OpLevelCostEstimator() {
// Syntactic sugar to build and return a lambda that takes an OpInfo and
// returns a cost.
- typedef Costs (OpLevelCostEstimator::*CostImpl)(const OpInfo& op_feature)
+ typedef Costs (OpLevelCostEstimator::*CostImpl)(const OpContext& op_context)
const;
- auto wrap = [this](CostImpl impl) -> std::function<Costs(const OpInfo&)> {
- return [this, impl](const OpInfo& op) { return (this->*impl)(op); };
+ auto wrap = [this](CostImpl impl) -> std::function<Costs(const OpContext&)> {
+ return [this, impl](const OpContext& op_context) {
+ return (this->*impl)(op_context);
+ };
};
device_cost_impl_ = {
@@ -272,18 +274,19 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
compute_memory_overlap_ = false;
}
-Costs OpLevelCostEstimator::PredictCosts(const OpInfo& op_features) const {
+Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
+ const auto& op_features = op_context.op_info;
auto it = device_cost_impl_.find(op_features.op());
if (it == device_cost_impl_.end()) {
if (elementwise_ops_.find(op_features.op()) != elementwise_ops_.end()) {
- return PredictCwiseOp(op_features);
+ return PredictCwiseOp(op_context);
}
VLOG(1) << "Missing implementation for op: " << op_features.op();
- return DummyExecutionTime(op_features);
+ return DummyExecutionTime(op_context);
}
- std::function<Costs(const OpInfo&)> estimator = it->second;
- Costs costs = estimator(op_features);
+ std::function<Costs(const OpContext&)> estimator = it->second;
+ Costs costs = estimator(op_context);
VLOG(1) << "Operation " << op_features.op() << " takes "
<< costs.execution_time.count() << " ns.";
return costs;
@@ -336,7 +339,8 @@ std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo(
return std::make_pair(gflops, bandwidth);
}
-Costs OpLevelCostEstimator::PredictCwiseOp(const OpInfo& op_features) const {
+Costs OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context) const {
+ const auto& op_features = op_context.op_info;
bool found_unknown_shapes = false;
// For unary or binary element-wise operations, op count is the element count
// of any input. We use the count for the largest input here to be more robust
@@ -369,9 +373,9 @@ Costs OpLevelCostEstimator::PredictCwiseOp(const OpInfo& op_features) const {
}
Costs OpLevelCostEstimator::DummyExecutionTime(
- const OpInfo& op_features) const {
+ const OpContext& op_context) const {
// Use CwiseOp time as an estimation
- auto costs = PredictCwiseOp(op_features);
+ auto costs = PredictCwiseOp(op_context);
costs.inaccurate = true;
return costs;
}
@@ -806,7 +810,8 @@ int64 OpLevelCostEstimator::CalculateOutputSize(
return total_output_size;
}
-Costs OpLevelCostEstimator::PredictConv2D(const OpInfo& op_features) const {
+Costs OpLevelCostEstimator::PredictConv2D(const OpContext& op_context) const {
+ const auto& op_features = op_context.op_info;
bool found_unknown_shapes = false;
auto costs = PredictOpCountBasedCost(
CountConv2DOperations(op_features, &found_unknown_shapes), op_features);
@@ -815,7 +820,8 @@ Costs OpLevelCostEstimator::PredictConv2D(const OpInfo& op_features) const {
}
Costs OpLevelCostEstimator::PredictConv2DBackpropInput(
- const OpInfo& op_features) const {
+ const OpContext& op_context) const {
+ const auto& op_features = op_context.op_info;
bool found_unknown_shapes = false;
auto costs =
PredictOpCountBasedCost(CountConv2DBackpropInputOperations(
@@ -826,7 +832,8 @@ Costs OpLevelCostEstimator::PredictConv2DBackpropInput(
}
Costs OpLevelCostEstimator::PredictConv2DBackpropFilter(
- const OpInfo& op_features) const {
+ const OpContext& op_context) const {
+ const auto& op_features = op_context.op_info;
bool found_unknown_shapes = false;
auto costs =
PredictOpCountBasedCost(CountConv2DBackpropFilterOperations(
@@ -836,7 +843,8 @@ Costs OpLevelCostEstimator::PredictConv2DBackpropFilter(
return costs;
}
-Costs OpLevelCostEstimator::PredictMatMul(const OpInfo& op_features) const {
+Costs OpLevelCostEstimator::PredictMatMul(const OpContext& op_context) const {
+ const auto& op_features = op_context.op_info;
bool found_unknown_shapes = false;
auto costs = PredictOpCountBasedCost(
CountMatMulOperations(op_features, &found_unknown_shapes), op_features);
@@ -844,13 +852,15 @@ Costs OpLevelCostEstimator::PredictMatMul(const OpInfo& op_features) const {
return costs;
}
-Costs OpLevelCostEstimator::PredictNoOp(const OpInfo& op_features) const {
+Costs OpLevelCostEstimator::PredictNoOp(const OpContext& op_context) const {
+ const auto& op_features = op_context.op_info;
VLOG(1) << "Op:" << op_features.op() << " Execution Time 0 (ns)";
return Costs::ZeroCosts();
}
Costs OpLevelCostEstimator::PredictBatchMatMul(
- const OpInfo& op_features) const {
+ const OpContext& op_context) const {
+ const auto& op_features = op_context.op_info;
bool found_unknown_shapes = false;
Costs costs = PredictOpCountBasedCost(
CountBatchMatMulOperations(op_features, &found_unknown_shapes),
@@ -859,7 +869,8 @@ Costs OpLevelCostEstimator::PredictBatchMatMul(
return costs;
}
-Costs OpLevelCostEstimator::PredictMetadata(const OpInfo& op_features) const {
+Costs OpLevelCostEstimator::PredictMetadata(const OpContext& op_context) const {
+ const auto& op_features = op_context.op_info;
Costs costs;
costs.max_memory = CalculateOutputSize(op_features, &costs.inaccurate);
// Metadata operations are so cheap we assume they take the minimum amount of