diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-02 12:36:08 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-02 12:38:35 -0700 |
commit | fc34c057d9d1118477b3e02870b97305c2d1af86 (patch) | |
tree | 10c330ada9ec81dd40df5676abab4ff5ca97d559 /tensorflow/core/grappler/costs/op_level_cost_estimator.cc | |
parent | 43f5b27f6064b64b7dbcfcae865829e3617a7112 (diff) |
Fix a bug in AvgPoolGrad op cost in extracting input x's shape. AvgPoolGrad
takes a shape tensor; hence, a value should be parsed from inputs(0) to extract
correct shape of x.
PiperOrigin-RevId: 191330762
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 29 |
1 files changed, 26 insertions, 3 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 0f6307cfdf..75258d0547 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -817,6 +817,7 @@ int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations( } if (!shape_found) { // Set the minimum filter size that's feasible. + input_shape.Clear(); for (int i = 0; i < 4; ++i) { input_shape.add_dim()->set_size(1); } @@ -859,6 +860,7 @@ int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations( } if (!shape_found) { // Set the minimum filter size that's feasible. + filter_shape.Clear(); for (int i = 0; i < 4; ++i) { filter_shape.add_dim()->set_size(1); } @@ -1242,10 +1244,31 @@ Costs OpLevelCostEstimator::PredictAvgPoolGrad( const OpContext& op_context) const { bool found_unknown_shapes = false; const auto& op_info = op_context.op_info; - // x: op_info.inputs(0) + // x's shape: op_info.inputs(0) // y_grad: op_info.inputs(1) - ConvolutionDimensions dims = OpDimensionsFromInputs( - op_info.inputs(0).shape(), op_info, &found_unknown_shapes); + + // Extract x_shape from op_info.inputs(0).value() or op_info.outputs(0). + bool shape_found = false; + TensorShapeProto x_shape; + if (op_info.inputs_size() >= 1 && op_info.inputs(0).has_value()) { + const TensorProto& value = op_info.inputs(0).value(); + shape_found = GetTensorShapeProtoFromTensorProto(value, &x_shape); + } + if (!shape_found && op_info.outputs_size() > 0) { + x_shape = op_info.outputs(0).shape(); + shape_found = true; + } + if (!shape_found) { + // Set the minimum shape that's feasible. + x_shape.Clear(); + for (int i = 0; i < 4; ++i) { + x_shape.add_dim()->set_size(1); + } + found_unknown_shapes = true; + } + + ConvolutionDimensions dims = + OpDimensionsFromInputs(x_shape, op_info, &found_unknown_shapes); int64 ops = 0; if (dims.kx <= dims.sx && dims.ky <= dims.sy) { |