aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
diff options
context:
space:
mode:
authorGravatar Max Galkin <maxgalkin@google.com>2018-03-21 12:53:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-21 12:55:38 -0700
commitee108441201ecb5fa9536573637623d712f9aa33 (patch)
tree74f4313979bdd43f3292064f951ecd23345e0539 /tensorflow/core/grappler/costs/op_level_cost_estimator.cc
parentbdd6f2253a76c707ff2ce2af9b560478891342eb (diff)
Further improve accuracy of op_level_cost_estimator (Gather, GatherV2, Slice).
PiperOrigin-RevId: 189952132
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc32
1 files changed, 26 insertions, 6 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 84ad8a3e84..d3ffa03fe2 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -48,6 +48,8 @@ constexpr char kSize[] = "Size";
constexpr char kStopGradient[] = "StopGradient";
constexpr char kPreventGradient[] = "PreventGradient";
constexpr char kGather[] = "Gather";
+constexpr char kGatherV2[] = "GatherV2";
+constexpr char kSlice[] = "Slice";
static const Costs::Duration kMinComputeTime(1);
@@ -169,7 +171,9 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
{kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kGather, wrap(&OpLevelCostEstimator::PredictGather)},
+ {kGather, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
+ {kGatherV2, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
+ {kSlice, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
{kPlaceholder, wrap(&OpLevelCostEstimator::PredictIdentity)},
{kIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)},
@@ -1049,17 +1053,33 @@ Costs OpLevelCostEstimator::PredictMetadata(const OpContext& op_context) const {
return costs;
}
-Costs OpLevelCostEstimator::PredictGather(const OpContext& op_context) const {
- // Gather op can have a very large input, but only the size of the output
- // matters, because indices may select only a very small subset of input.
-
+Costs OpLevelCostEstimator::PredictGatherOrSlice(
+ const OpContext& op_context) const {
+ // Gather & Slice ops can have a very large input, but only access a small
+ // part of it. For these op the size of the output determines the memory cost.
const auto& op_info = op_context.op_info;
bool unknown_shapes = false;
+
+ // Each output element is a copy of some element from input.
+ // For roofline estimate we assume each copy has a unit cost.
const int64 op_count =
CalculateTensorElementCount(op_info.outputs(0), &unknown_shapes);
+
const double output_size = CalculateOutputSize(op_info, &unknown_shapes);
- const double total_io = 2 * output_size;
+ double input_size = output_size;
+ if (op_info.op() == "Slice") {
+ // Add 'begin' & 'size' tensors sizes.
+ input_size +=
+ CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes) +
+ CalculateTensorElementCount(op_info.inputs(2), &unknown_shapes);
+ } else {
+ // Assuming this is "Gather" or "GatherV2" op, add 'indices' size.
+ input_size +=
+ CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes);
+ }
+
+ const double total_io = input_size + output_size;
Costs costs = PredictOpCountBasedCost(op_count, total_io, op_info);
costs.inaccurate = unknown_shapes;
costs.max_memory = output_size;