diff options
author | Max Galkin <maxgalkin@google.com> | 2018-03-21 12:53:04 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-21 12:55:38 -0700 |
commit | ee108441201ecb5fa9536573637623d712f9aa33 (patch) | |
tree | 74f4313979bdd43f3292064f951ecd23345e0539 /tensorflow/core/grappler/costs/op_level_cost_estimator.cc | |
parent | bdd6f2253a76c707ff2ce2af9b560478891342eb (diff) |
Further improve accuracy of op_level_cost_estimator (Gather, GatherV2, Slice).
PiperOrigin-RevId: 189952132
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 32 |
1 files changed, 26 insertions, 6 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 84ad8a3e84..d3ffa03fe2 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -48,6 +48,8 @@ constexpr char kSize[] = "Size"; constexpr char kStopGradient[] = "StopGradient"; constexpr char kPreventGradient[] = "PreventGradient"; constexpr char kGather[] = "Gather"; +constexpr char kGatherV2[] = "GatherV2"; +constexpr char kSlice[] = "Slice"; static const Costs::Duration kMinComputeTime(1); @@ -169,7 +171,9 @@ OpLevelCostEstimator::OpLevelCostEstimator() { {kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)}, - {kGather, wrap(&OpLevelCostEstimator::PredictGather)}, + {kGather, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)}, + {kGatherV2, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)}, + {kSlice, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)}, {kPlaceholder, wrap(&OpLevelCostEstimator::PredictIdentity)}, {kIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)}, @@ -1049,17 +1053,33 @@ Costs OpLevelCostEstimator::PredictMetadata(const OpContext& op_context) const { return costs; } -Costs OpLevelCostEstimator::PredictGather(const OpContext& op_context) const { - // Gather op can have a very large input, but only the size of the output - // matters, because indices may select only a very small subset of input. - +Costs OpLevelCostEstimator::PredictGatherOrSlice( + const OpContext& op_context) const { + // Gather & Slice ops can have a very large input, but only access a small + // part of it. For these op the size of the output determines the memory cost. const auto& op_info = op_context.op_info; bool unknown_shapes = false; + + // Each output element is a copy of some element from input. + // For roofline estimate we assume each copy has a unit cost. const int64 op_count = CalculateTensorElementCount(op_info.outputs(0), &unknown_shapes); + const double output_size = CalculateOutputSize(op_info, &unknown_shapes); - const double total_io = 2 * output_size; + double input_size = output_size; + if (op_info.op() == "Slice") { + // Add 'begin' & 'size' tensors sizes. + input_size += + CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes) + + CalculateTensorElementCount(op_info.inputs(2), &unknown_shapes); + } else { + // Assuming this is "Gather" or "GatherV2" op, add 'indices' size. + input_size += + CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes); + } + + const double total_io = input_size + output_size; Costs costs = PredictOpCountBasedCost(op_count, total_io, op_info); costs.inaccurate = unknown_shapes; costs.max_memory = output_size; |