aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
diff options
context:
space:
mode:
authorGravatar Max Galkin <maxgalkin@google.com>2018-04-02 21:03:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-02 21:07:58 -0700
commit53eeeb7ac4a876a59ae975a8d6dd8a48f645b7b7 (patch)
tree1657e10d561ab747ec47749ee8111e443c57d393 /tensorflow/core/grappler/costs/op_level_cost_estimator.cc
parent3027f580046866cb74d5edf4e41c9406e007234c (diff)
Re-enable Gather and Slice estimators with output size check.
PiperOrigin-RevId: 191391805
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc10
1 files changed, 7 insertions, 3 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 75258d0547..14e46ecdd9 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -202,12 +202,9 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
{kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
- // TODO(76227186): re-enable with output size check & test
- /*
{kGather, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
{kGatherV2, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
{kSlice, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
- */
{kPlaceholder, wrap(&OpLevelCostEstimator::PredictIdentity)},
{kIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)},
@@ -1058,6 +1055,13 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice(
// part of it. For these op the size of the output determines the memory cost.
const auto& op_info = op_context.op_info;
+ const int inputs_needed = op_info.op() == "Slice" ? 3 : 2;
+ if (op_info.outputs_size() == 0 || op_info.inputs_size() < inputs_needed) {
+ Costs costs = Costs::ZeroCosts();
+ costs.inaccurate = true;
+ return costs;
+ }
+
bool unknown_shapes = false;
// Each output element is a copy of some element from input.