diff options
author | Max Galkin <maxgalkin@google.com> | 2018-03-21 12:53:04 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-21 12:55:38 -0700 |
commit | ee108441201ecb5fa9536573637623d712f9aa33 (patch) | |
tree | 74f4313979bdd43f3292064f951ecd23345e0539 | |
parent | bdd6f2253a76c707ff2ce2af9b560478891342eb (diff) |
Further improve accuracy of op_level_cost_estimator (Gather, GatherV2, Slice).
PiperOrigin-RevId: 189952132
3 files changed, 47 insertions, 9 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 84ad8a3e84..d3ffa03fe2 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -48,6 +48,8 @@ constexpr char kSize[] = "Size"; constexpr char kStopGradient[] = "StopGradient"; constexpr char kPreventGradient[] = "PreventGradient"; constexpr char kGather[] = "Gather"; +constexpr char kGatherV2[] = "GatherV2"; +constexpr char kSlice[] = "Slice"; static const Costs::Duration kMinComputeTime(1); @@ -169,7 +171,9 @@ OpLevelCostEstimator::OpLevelCostEstimator() { {kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)}, - {kGather, wrap(&OpLevelCostEstimator::PredictGather)}, + {kGather, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)}, + {kGatherV2, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)}, + {kSlice, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)}, {kPlaceholder, wrap(&OpLevelCostEstimator::PredictIdentity)}, {kIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)}, @@ -1049,17 +1053,33 @@ Costs OpLevelCostEstimator::PredictMetadata(const OpContext& op_context) const { return costs; } -Costs OpLevelCostEstimator::PredictGather(const OpContext& op_context) const { - // Gather op can have a very large input, but only the size of the output - // matters, because indices may select only a very small subset of input. - +Costs OpLevelCostEstimator::PredictGatherOrSlice( + const OpContext& op_context) const { + // Gather & Slice ops can have a very large input, but only access a small + // part of it. For these op the size of the output determines the memory cost. const auto& op_info = op_context.op_info; bool unknown_shapes = false; + + // Each output element is a copy of some element from input. + // For roofline estimate we assume each copy has a unit cost. const int64 op_count = CalculateTensorElementCount(op_info.outputs(0), &unknown_shapes); + const double output_size = CalculateOutputSize(op_info, &unknown_shapes); - const double total_io = 2 * output_size; + double input_size = output_size; + if (op_info.op() == "Slice") { + // Add 'begin' & 'size' tensors sizes. + input_size += + CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes) + + CalculateTensorElementCount(op_info.inputs(2), &unknown_shapes); + } else { + // Assuming this is "Gather" or "GatherV2" op, add 'indices' size. + input_size += + CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes); + } + + const double total_io = input_size + output_size; Costs costs = PredictOpCountBasedCost(op_count, total_io, op_info); costs.inaccurate = unknown_shapes; costs.max_memory = output_size; diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index e5dd31a7a2..1b3babb206 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -144,7 +144,7 @@ class OpLevelCostEstimator { Costs PredictVariable(const OpContext& op_context) const; Costs PredictBatchMatMul(const OpContext& op_context) const; Costs PredictMetadata(const OpContext& op_context) const; - Costs PredictGather(const OpContext& op_context) const; + Costs PredictGatherOrSlice(const OpContext& op_context) const; // Utility function for safe division. Returns 0 // if rhs is 0 or negative. diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc index a92f230101..f2a9615dfb 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc @@ -206,9 +206,27 @@ TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) { DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &op_context.op_info); auto cost = estimator_.PredictCosts(op_context); - EXPECT_EQ(Costs::Duration(128), cost.memory_time); + EXPECT_EQ(Costs::Duration(130), cost.memory_time); EXPECT_EQ(Costs::Duration(16), cost.compute_time); - EXPECT_EQ(Costs::Duration(144), cost.execution_time); + EXPECT_EQ(Costs::Duration(146), cost.execution_time); + EXPECT_FALSE(cost.inaccurate); +} + +TEST_F(OpLevelCostEstimatorTest, TestSliceCosts) { + OpContext op_context; + SetCpuDevice(&op_context.op_info); + op_context.op_info.set_op("Slice"); + + // Huge first input shouldn't affect Slice execution and memory costs. + DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info); + DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info); + DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info); + DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, &op_context.op_info); + + auto cost = estimator_.PredictCosts(op_context); + EXPECT_EQ(Costs::Duration(81), cost.memory_time); + EXPECT_EQ(Costs::Duration(10), cost.compute_time); + EXPECT_EQ(Costs::Duration(91), cost.execution_time); EXPECT_FALSE(cost.inaccurate); } |