aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Max Galkin <maxgalkin@google.com>2018-03-21 12:53:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-21 12:55:38 -0700
commitee108441201ecb5fa9536573637623d712f9aa33 (patch)
tree74f4313979bdd43f3292064f951ecd23345e0539
parentbdd6f2253a76c707ff2ce2af9b560478891342eb (diff)
Further improve accuracy of op_level_cost_estimator (Gather, GatherV2, Slice).
PiperOrigin-RevId: 189952132
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc32
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.h2
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc22
3 files changed, 47 insertions, 9 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 84ad8a3e84..d3ffa03fe2 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -48,6 +48,8 @@ constexpr char kSize[] = "Size";
constexpr char kStopGradient[] = "StopGradient";
constexpr char kPreventGradient[] = "PreventGradient";
constexpr char kGather[] = "Gather";
+constexpr char kGatherV2[] = "GatherV2";
+constexpr char kSlice[] = "Slice";
static const Costs::Duration kMinComputeTime(1);
@@ -169,7 +171,9 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
{kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kGather, wrap(&OpLevelCostEstimator::PredictGather)},
+ {kGather, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
+ {kGatherV2, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
+ {kSlice, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
{kPlaceholder, wrap(&OpLevelCostEstimator::PredictIdentity)},
{kIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)},
@@ -1049,17 +1053,33 @@ Costs OpLevelCostEstimator::PredictMetadata(const OpContext& op_context) const {
return costs;
}
-Costs OpLevelCostEstimator::PredictGather(const OpContext& op_context) const {
- // Gather op can have a very large input, but only the size of the output
- // matters, because indices may select only a very small subset of input.
-
+Costs OpLevelCostEstimator::PredictGatherOrSlice(
+ const OpContext& op_context) const {
+ // Gather & Slice ops can have a very large input, but only access a small
+ // part of it. For these op the size of the output determines the memory cost.
const auto& op_info = op_context.op_info;
bool unknown_shapes = false;
+
+ // Each output element is a copy of some element from input.
+ // For roofline estimate we assume each copy has a unit cost.
const int64 op_count =
CalculateTensorElementCount(op_info.outputs(0), &unknown_shapes);
+
const double output_size = CalculateOutputSize(op_info, &unknown_shapes);
- const double total_io = 2 * output_size;
+ double input_size = output_size;
+ if (op_info.op() == "Slice") {
+ // Add 'begin' & 'size' tensors sizes.
+ input_size +=
+ CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes) +
+ CalculateTensorElementCount(op_info.inputs(2), &unknown_shapes);
+ } else {
+ // Assuming this is "Gather" or "GatherV2" op, add 'indices' size.
+ input_size +=
+ CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes);
+ }
+
+ const double total_io = input_size + output_size;
Costs costs = PredictOpCountBasedCost(op_count, total_io, op_info);
costs.inaccurate = unknown_shapes;
costs.max_memory = output_size;
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index e5dd31a7a2..1b3babb206 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -144,7 +144,7 @@ class OpLevelCostEstimator {
Costs PredictVariable(const OpContext& op_context) const;
Costs PredictBatchMatMul(const OpContext& op_context) const;
Costs PredictMetadata(const OpContext& op_context) const;
- Costs PredictGather(const OpContext& op_context) const;
+ Costs PredictGatherOrSlice(const OpContext& op_context) const;
// Utility function for safe division. Returns 0
// if rhs is 0 or negative.
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index a92f230101..f2a9615dfb 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -206,9 +206,27 @@ TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) {
DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &op_context.op_info);
auto cost = estimator_.PredictCosts(op_context);
- EXPECT_EQ(Costs::Duration(128), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(130), cost.memory_time);
EXPECT_EQ(Costs::Duration(16), cost.compute_time);
- EXPECT_EQ(Costs::Duration(144), cost.execution_time);
+ EXPECT_EQ(Costs::Duration(146), cost.execution_time);
+ EXPECT_FALSE(cost.inaccurate);
+}
+
+TEST_F(OpLevelCostEstimatorTest, TestSliceCosts) {
+ OpContext op_context;
+ SetCpuDevice(&op_context.op_info);
+ op_context.op_info.set_op("Slice");
+
+ // Huge first input shouldn't affect Slice execution and memory costs.
+ DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
+ DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
+ DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
+ DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, &op_context.op_info);
+
+ auto cost = estimator_.PredictCosts(op_context);
+ EXPECT_EQ(Costs::Duration(81), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(10), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(91), cost.execution_time);
EXPECT_FALSE(cost.inaccurate);
}