diff options
author | 2018-08-12 16:21:41 -0700 | |
---|---|---|
committer | 2018-08-12 16:21:41 -0700 | |
commit | 9523a98466d16cf01fc76a67b489f1124cf626ac (patch) | |
tree | bd4c460b67fab60c2fb1a6c56bf22d1cbb5391e6 /tensorflow/core/grappler/costs/op_level_cost_estimator.cc | |
parent | 93e950c308071071f35d6dcb35b9f91b8a34876c (diff) | |
parent | 1a22b0b982fa1a953651b98af8f3cd30542048fd (diff) |
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 5b303f6ccb..6406a4bdbf 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -449,6 +449,7 @@ Costs OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context) const { if (found_unknown_shapes || !is_known_elementwise_op) { costs.inaccurate = true; } + costs.num_ops_with_unknown_shapes = found_unknown_shapes; return costs; } @@ -469,6 +470,7 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost( const double total_io_bytes = input_size + output_size; Costs costs = PredictOpCountBasedCost(operations, total_io_bytes, op_info); costs.inaccurate = unknown_shapes; + costs.num_ops_with_unknown_shapes = unknown_shapes; costs.max_memory = output_size; return costs; } @@ -627,6 +629,7 @@ int64 OpLevelCostEstimator::CountMatMulOperations( if (op_features.inputs_size() < 2) { LOG(ERROR) << "Need 2 inputs but got " << op_features.inputs_size(); + // TODO(pcma): Try to separate invalid inputs from unknown shapes *found_unknown_shapes = true; return 0; } @@ -694,11 +697,13 @@ int64 OpLevelCostEstimator::CountBatchMatMulOperations( const OpInfo& op_features, bool* found_unknown_shapes) const { if (op_features.op() != kBatchMatMul) { LOG(ERROR) << "Invalid Operation: " << op_features.op(); + // TODO(pcma): Try to separate invalid inputs from unknown shapes *found_unknown_shapes = true; return 0; } if (op_features.inputs_size() != 2) { LOG(ERROR) << "Expected 2 inputs but got " << op_features.inputs_size(); + // TODO(pcma): Try to separate invalid inputs from unknown shapes *found_unknown_shapes = true; return 0; } @@ -858,6 +863,7 @@ int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations( "kDepthwiseConv2dNativeBackpropInput"; if (op_features.inputs_size() < 2) { + // TODO(pcma): Try to separate invalid inputs from unknown shapes *found_unknown_shapes = true; return ops; } @@ -935,6 +941,7 @@ int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations( } if (op_features.inputs_size() < 1) { + // TODO(pcma): Try to separate invalid inputs from unknown shapes *found_unknown_shapes = true; return ops; } @@ -1037,6 +1044,7 @@ Costs OpLevelCostEstimator::PredictConv2D(const OpContext& op_context) const { auto costs = PredictOpCountBasedCost( CountConv2DOperations(op_features, &found_unknown_shapes), op_features); costs.inaccurate = found_unknown_shapes; + costs.num_ops_with_unknown_shapes = found_unknown_shapes; return costs; } @@ -1049,6 +1057,7 @@ Costs OpLevelCostEstimator::PredictConv2DBackpropInput( op_features, nullptr, &found_unknown_shapes), op_features); costs.inaccurate = found_unknown_shapes; + costs.num_ops_with_unknown_shapes = found_unknown_shapes; return costs; } @@ -1061,6 +1070,7 @@ Costs OpLevelCostEstimator::PredictConv2DBackpropFilter( op_features, nullptr, &found_unknown_shapes), op_features); costs.inaccurate = found_unknown_shapes; + costs.num_ops_with_unknown_shapes = found_unknown_shapes; return costs; } @@ -1148,6 +1158,7 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation( // Construct component operations and run the cost computation. auto costs = PredictFusedOp(op_context_with_output, component_ops); costs.inaccurate |= found_unknown_shapes; + costs.num_ops_with_unknown_shapes = costs.inaccurate; return costs; } @@ -1157,6 +1168,7 @@ Costs OpLevelCostEstimator::PredictMatMul(const OpContext& op_context) const { auto costs = PredictOpCountBasedCost( CountMatMulOperations(op_features, &found_unknown_shapes), op_features); costs.inaccurate = found_unknown_shapes; + costs.num_ops_with_unknown_shapes = found_unknown_shapes; return costs; } @@ -1171,6 +1183,7 @@ Costs OpLevelCostEstimator::PredictIdentity(const OpContext& op_context) const { VLOG(1) << "Op:" << op_features.op() << " Execution Time 0 (ns)"; Costs result = Costs::ZeroCosts(); result.max_memory = CalculateOutputSize(op_features, &result.inaccurate); + result.num_ops_with_unknown_shapes = result.inaccurate; // Assign the minimum amount of time we can represent to the identity op since // it tends to be really cheap. result.compute_time = kMinComputeTime; @@ -1184,6 +1197,7 @@ Costs OpLevelCostEstimator::PredictVariable(const OpContext& op_context) const { Costs result = Costs::ZeroCosts(); result.persistent_memory = CalculateOutputSize(op_features, &result.inaccurate); + result.num_ops_with_unknown_shapes = result.inaccurate; result.compute_time = kMinComputeTime; result.execution_time = result.execution_time; @@ -1198,6 +1212,7 @@ Costs OpLevelCostEstimator::PredictBatchMatMul( CountBatchMatMulOperations(op_features, &found_unknown_shapes), op_features); costs.inaccurate = found_unknown_shapes; + costs.num_ops_with_unknown_shapes = found_unknown_shapes; return costs; } @@ -1205,6 +1220,7 @@ Costs OpLevelCostEstimator::PredictMetadata(const OpContext& op_context) const { const auto& op_features = op_context.op_info; Costs costs = Costs::ZeroCosts(); costs.max_memory = CalculateOutputSize(op_features, &costs.inaccurate); + costs.num_ops_with_unknown_shapes = costs.inaccurate; // Metadata operations are so cheap we assume they take the minimum amount of // time we can represent (1 ns). costs.compute_time = kMinComputeTime; @@ -1249,6 +1265,7 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice( const double total_io = input_size + output_size; Costs costs = PredictOpCountBasedCost(op_count, total_io, op_info); costs.inaccurate = unknown_shapes; + costs.num_ops_with_unknown_shapes = unknown_shapes; costs.max_memory = output_size; return costs; @@ -1390,6 +1407,7 @@ Costs OpLevelCostEstimator::PredictMaxPool(const OpContext& op_context) const { Costs costs = PredictOpCountBasedCost( ops, total_input_size + total_output_size, op_info); costs.inaccurate = found_unknown_shapes; + costs.num_ops_with_unknown_shapes = found_unknown_shapes; costs.max_memory = total_output_size; return costs; } @@ -1432,6 +1450,7 @@ Costs OpLevelCostEstimator::PredictMaxPoolGrad( Costs costs = PredictOpCountBasedCost( ops, total_input_size + total_output_size, op_info); costs.inaccurate = found_unknown_shapes; + costs.num_ops_with_unknown_shapes = found_unknown_shapes; costs.max_memory = total_output_size; return costs; } @@ -1464,6 +1483,7 @@ Costs OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context) const { Costs costs = PredictOpCountBasedCost( ops, total_input_size + total_output_size, op_info); costs.inaccurate = found_unknown_shapes; + costs.num_ops_with_unknown_shapes = found_unknown_shapes; costs.max_memory = total_output_size; return costs; } @@ -1516,6 +1536,7 @@ Costs OpLevelCostEstimator::PredictAvgPoolGrad( Costs costs = PredictOpCountBasedCost( ops, total_input_size + total_output_size, op_info); costs.inaccurate = found_unknown_shapes; + costs.num_ops_with_unknown_shapes = found_unknown_shapes; costs.max_memory = total_output_size; return costs; } @@ -1562,6 +1583,7 @@ Costs OpLevelCostEstimator::PredictFusedBatchNorm( ops, total_input_size + total_output_size + total_internal_read_size, op_info); costs.inaccurate = found_unknown_shapes; + costs.num_ops_with_unknown_shapes = found_unknown_shapes; costs.max_memory = total_output_size; return costs; } @@ -1595,6 +1617,7 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad( ops, total_input_size + total_output_size + total_internal_read_size, op_info); costs.inaccurate = found_unknown_shapes; + costs.num_ops_with_unknown_shapes = found_unknown_shapes; costs.max_memory = total_output_size; return costs; } |