aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
diff options
context:
space:
mode:
author: Peter Ma <pcma@google.com> 2018-08-10 15:03:22 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-08-10 15:13:09 -0700
commit: 2625345c727b14f8e770d4f980fe86e9ccc8b03d (patch)
tree: ba456d8a49c073e56c8f78d26bc17f413937fff0 /tensorflow/core/grappler/costs/op_level_cost_estimator.cc
parent: 83a1435684149e381521de528c3af40daa784570 (diff)
Add two counters to the Costs struct: the total number of ops processed/predicted, and the number of ops predicted with unknown shapes.
PiperOrigin-RevId: 208274158
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r-- tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 23
1 file changed, 23 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 5b303f6ccb..6406a4bdbf 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -449,6 +449,7 @@ Costs OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context) const {
if (found_unknown_shapes || !is_known_elementwise_op) {
costs.inaccurate = true;
}
+ costs.num_ops_with_unknown_shapes = found_unknown_shapes;
return costs;
}
@@ -469,6 +470,7 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost(
const double total_io_bytes = input_size + output_size;
Costs costs = PredictOpCountBasedCost(operations, total_io_bytes, op_info);
costs.inaccurate = unknown_shapes;
+ costs.num_ops_with_unknown_shapes = unknown_shapes;
costs.max_memory = output_size;
return costs;
}
@@ -627,6 +629,7 @@ int64 OpLevelCostEstimator::CountMatMulOperations(
if (op_features.inputs_size() < 2) {
LOG(ERROR) << "Need 2 inputs but got " << op_features.inputs_size();
+ // TODO(pcma): Try to separate invalid inputs from unknown shapes
*found_unknown_shapes = true;
return 0;
}
@@ -694,11 +697,13 @@ int64 OpLevelCostEstimator::CountBatchMatMulOperations(
const OpInfo& op_features, bool* found_unknown_shapes) const {
if (op_features.op() != kBatchMatMul) {
LOG(ERROR) << "Invalid Operation: " << op_features.op();
+ // TODO(pcma): Try to separate invalid inputs from unknown shapes
*found_unknown_shapes = true;
return 0;
}
if (op_features.inputs_size() != 2) {
LOG(ERROR) << "Expected 2 inputs but got " << op_features.inputs_size();
+ // TODO(pcma): Try to separate invalid inputs from unknown shapes
*found_unknown_shapes = true;
return 0;
}
@@ -858,6 +863,7 @@ int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations(
"kDepthwiseConv2dNativeBackpropInput";
if (op_features.inputs_size() < 2) {
+ // TODO(pcma): Try to separate invalid inputs from unknown shapes
*found_unknown_shapes = true;
return ops;
}
@@ -935,6 +941,7 @@ int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations(
}
if (op_features.inputs_size() < 1) {
+ // TODO(pcma): Try to separate invalid inputs from unknown shapes
*found_unknown_shapes = true;
return ops;
}
@@ -1037,6 +1044,7 @@ Costs OpLevelCostEstimator::PredictConv2D(const OpContext& op_context) const {
auto costs = PredictOpCountBasedCost(
CountConv2DOperations(op_features, &found_unknown_shapes), op_features);
costs.inaccurate = found_unknown_shapes;
+ costs.num_ops_with_unknown_shapes = found_unknown_shapes;
return costs;
}
@@ -1049,6 +1057,7 @@ Costs OpLevelCostEstimator::PredictConv2DBackpropInput(
op_features, nullptr, &found_unknown_shapes),
op_features);
costs.inaccurate = found_unknown_shapes;
+ costs.num_ops_with_unknown_shapes = found_unknown_shapes;
return costs;
}
@@ -1061,6 +1070,7 @@ Costs OpLevelCostEstimator::PredictConv2DBackpropFilter(
op_features, nullptr, &found_unknown_shapes),
op_features);
costs.inaccurate = found_unknown_shapes;
+ costs.num_ops_with_unknown_shapes = found_unknown_shapes;
return costs;
}
@@ -1148,6 +1158,7 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
// Construct component operations and run the cost computation.
auto costs = PredictFusedOp(op_context_with_output, component_ops);
costs.inaccurate |= found_unknown_shapes;
+ costs.num_ops_with_unknown_shapes = costs.inaccurate;
return costs;
}
@@ -1157,6 +1168,7 @@ Costs OpLevelCostEstimator::PredictMatMul(const OpContext& op_context) const {
auto costs = PredictOpCountBasedCost(
CountMatMulOperations(op_features, &found_unknown_shapes), op_features);
costs.inaccurate = found_unknown_shapes;
+ costs.num_ops_with_unknown_shapes = found_unknown_shapes;
return costs;
}
@@ -1171,6 +1183,7 @@ Costs OpLevelCostEstimator::PredictIdentity(const OpContext& op_context) const {
VLOG(1) << "Op:" << op_features.op() << " Execution Time 0 (ns)";
Costs result = Costs::ZeroCosts();
result.max_memory = CalculateOutputSize(op_features, &result.inaccurate);
+ result.num_ops_with_unknown_shapes = result.inaccurate;
// Assign the minimum amount of time we can represent to the identity op since
// it tends to be really cheap.
result.compute_time = kMinComputeTime;
@@ -1184,6 +1197,7 @@ Costs OpLevelCostEstimator::PredictVariable(const OpContext& op_context) const {
Costs result = Costs::ZeroCosts();
result.persistent_memory =
CalculateOutputSize(op_features, &result.inaccurate);
+ result.num_ops_with_unknown_shapes = result.inaccurate;
result.compute_time = kMinComputeTime;
result.execution_time = result.execution_time;
@@ -1198,6 +1212,7 @@ Costs OpLevelCostEstimator::PredictBatchMatMul(
CountBatchMatMulOperations(op_features, &found_unknown_shapes),
op_features);
costs.inaccurate = found_unknown_shapes;
+ costs.num_ops_with_unknown_shapes = found_unknown_shapes;
return costs;
}
@@ -1205,6 +1220,7 @@ Costs OpLevelCostEstimator::PredictMetadata(const OpContext& op_context) const {
const auto& op_features = op_context.op_info;
Costs costs = Costs::ZeroCosts();
costs.max_memory = CalculateOutputSize(op_features, &costs.inaccurate);
+ costs.num_ops_with_unknown_shapes = costs.inaccurate;
// Metadata operations are so cheap we assume they take the minimum amount of
// time we can represent (1 ns).
costs.compute_time = kMinComputeTime;
@@ -1249,6 +1265,7 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice(
const double total_io = input_size + output_size;
Costs costs = PredictOpCountBasedCost(op_count, total_io, op_info);
costs.inaccurate = unknown_shapes;
+ costs.num_ops_with_unknown_shapes = unknown_shapes;
costs.max_memory = output_size;
return costs;
@@ -1390,6 +1407,7 @@ Costs OpLevelCostEstimator::PredictMaxPool(const OpContext& op_context) const {
Costs costs = PredictOpCountBasedCost(
ops, total_input_size + total_output_size, op_info);
costs.inaccurate = found_unknown_shapes;
+ costs.num_ops_with_unknown_shapes = found_unknown_shapes;
costs.max_memory = total_output_size;
return costs;
}
@@ -1432,6 +1450,7 @@ Costs OpLevelCostEstimator::PredictMaxPoolGrad(
Costs costs = PredictOpCountBasedCost(
ops, total_input_size + total_output_size, op_info);
costs.inaccurate = found_unknown_shapes;
+ costs.num_ops_with_unknown_shapes = found_unknown_shapes;
costs.max_memory = total_output_size;
return costs;
}
@@ -1464,6 +1483,7 @@ Costs OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context) const {
Costs costs = PredictOpCountBasedCost(
ops, total_input_size + total_output_size, op_info);
costs.inaccurate = found_unknown_shapes;
+ costs.num_ops_with_unknown_shapes = found_unknown_shapes;
costs.max_memory = total_output_size;
return costs;
}
@@ -1516,6 +1536,7 @@ Costs OpLevelCostEstimator::PredictAvgPoolGrad(
Costs costs = PredictOpCountBasedCost(
ops, total_input_size + total_output_size, op_info);
costs.inaccurate = found_unknown_shapes;
+ costs.num_ops_with_unknown_shapes = found_unknown_shapes;
costs.max_memory = total_output_size;
return costs;
}
@@ -1562,6 +1583,7 @@ Costs OpLevelCostEstimator::PredictFusedBatchNorm(
ops, total_input_size + total_output_size + total_internal_read_size,
op_info);
costs.inaccurate = found_unknown_shapes;
+ costs.num_ops_with_unknown_shapes = found_unknown_shapes;
costs.max_memory = total_output_size;
return costs;
}
@@ -1595,6 +1617,7 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad(
ops, total_input_size + total_output_size + total_internal_read_size,
op_info);
costs.inaccurate = found_unknown_shapes;
+ costs.num_ops_with_unknown_shapes = found_unknown_shapes;
costs.max_memory = total_output_size;
return costs;
}