aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-02 12:36:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-02 12:38:35 -0700
commitfc34c057d9d1118477b3e02870b97305c2d1af86 (patch)
tree10c330ada9ec81dd40df5676abab4ff5ca97d559 /tensorflow/core/grappler/costs/op_level_cost_estimator.cc
parent43f5b27f6064b64b7dbcfcae865829e3617a7112 (diff)
Fix a bug in AvgPoolGrad op cost in extracting input x's shape. AvgPoolGrad
takes a shape tensor; hence, a value should be parsed from inputs(0) to extract correct shape of x. PiperOrigin-RevId: 191330762
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc29
1 files changed, 26 insertions, 3 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 0f6307cfdf..75258d0547 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -817,6 +817,7 @@ int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations(
}
if (!shape_found) {
// Set the minimum filter size that's feasible.
+ input_shape.Clear();
for (int i = 0; i < 4; ++i) {
input_shape.add_dim()->set_size(1);
}
@@ -859,6 +860,7 @@ int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations(
}
if (!shape_found) {
// Set the minimum filter size that's feasible.
+ filter_shape.Clear();
for (int i = 0; i < 4; ++i) {
filter_shape.add_dim()->set_size(1);
}
@@ -1242,10 +1244,31 @@ Costs OpLevelCostEstimator::PredictAvgPoolGrad(
const OpContext& op_context) const {
bool found_unknown_shapes = false;
const auto& op_info = op_context.op_info;
- // x: op_info.inputs(0)
+ // x's shape: op_info.inputs(0)
// y_grad: op_info.inputs(1)
- ConvolutionDimensions dims = OpDimensionsFromInputs(
- op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
+
+ // Extract x_shape from op_info.inputs(0).value() or op_info.outputs(0).
+ bool shape_found = false;
+ TensorShapeProto x_shape;
+ if (op_info.inputs_size() >= 1 && op_info.inputs(0).has_value()) {
+ const TensorProto& value = op_info.inputs(0).value();
+ shape_found = GetTensorShapeProtoFromTensorProto(value, &x_shape);
+ }
+ if (!shape_found && op_info.outputs_size() > 0) {
+ x_shape = op_info.outputs(0).shape();
+ shape_found = true;
+ }
+ if (!shape_found) {
+ // Set the minimum shape that's feasible.
+ x_shape.Clear();
+ for (int i = 0; i < 4; ++i) {
+ x_shape.add_dim()->set_size(1);
+ }
+ found_unknown_shapes = true;
+ }
+
+ ConvolutionDimensions dims =
+ OpDimensionsFromInputs(x_shape, op_info, &found_unknown_shapes);
int64 ops = 0;
if (dims.kx <= dims.sx && dims.ky <= dims.sy) {