diff options
author | Max Galkin <maxgalkin@google.com> | 2018-04-02 21:03:06 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-02 21:07:58 -0700 |
commit | 53eeeb7ac4a876a59ae975a8d6dd8a48f645b7b7 (patch) | |
tree | 1657e10d561ab747ec47749ee8111e443c57d393 /tensorflow/core/grappler/costs/op_level_cost_estimator.cc | |
parent | 3027f580046866cb74d5edf4e41c9406e007234c (diff) |
Re-enable Gather and Slice estimators with output size check.
PiperOrigin-RevId: 191391805
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 75258d0547..14e46ecdd9 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -202,12 +202,9 @@ OpLevelCostEstimator::OpLevelCostEstimator() { {kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)}, - // TODO(76227186): re-enable with output size check & test - /* {kGather, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)}, {kGatherV2, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)}, {kSlice, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)}, - */ {kPlaceholder, wrap(&OpLevelCostEstimator::PredictIdentity)}, {kIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)}, @@ -1058,6 +1055,13 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice( // part of it. For these op the size of the output determines the memory cost. const auto& op_info = op_context.op_info; + const int inputs_needed = op_info.op() == "Slice" ? 3 : 2; + if (op_info.outputs_size() == 0 || op_info.inputs_size() < inputs_needed) { + Costs costs = Costs::ZeroCosts(); + costs.inaccurate = true; + return costs; + } + bool unknown_shapes = false; // Each output element is a copy of some element from input. |