aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2018-02-13 12:31:59 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-13 12:35:57 -0800
commitdedafb73031c5588c1254e8fabd553031b15870a (patch)
tree2cd3121270126c45126b4cdac2fc4f3d34775a79 /tensorflow/core/grappler/costs/op_level_cost_estimator.cc
parenta3496e14f7ca0f77b804b5be87cd43f919a7c09f (diff)
Extract the filter and input shape for Conv2DBackpropFilter/Conv2DBackpropInput
from the corresponding op inputs whenever possible. PiperOrigin-RevId: 185570750
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc63
1 files changed, 48 insertions, 15 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 76db1afd4a..a57cfdd989 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -724,18 +724,35 @@ int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations(
bool* found_unknown_shapes) const {
int64 ops = 0;
- if (op_features.op() != kConv2dBackpropInput) {
- LOG(ERROR) << "Invalid Operation";
+ DCHECK_EQ(kConv2dBackpropInput, op_features.op());
+
+ if (op_features.inputs_size() < 2) {
+ *found_unknown_shapes = true;
return ops;
}
- if (op_features.outputs_size() != 1) {
- // Need _output_shapes for input shape.
- LOG(ERROR) << "No output shape in Conv2DBackpropInput op.";
- return ops;
+ TensorShapeProto input_shape;
+ if (op_features.inputs(0).has_value()) {
+ const TensorProto& value = op_features.inputs(0).value();
+ if (value.int64_val_size() > 0) {
+ for (int i = 0; i < value.int64_val_size(); ++i) {
+ input_shape.add_dim()->set_size(value.int64_val(i));
+ }
+ } else {
+ for (int i = 0; i < value.int_val_size(); ++i) {
+ input_shape.add_dim()->set_size(value.int_val(i));
+ }
+ }
+ } else if (op_features.outputs_size() == 1) {
+ input_shape = op_features.outputs(0).shape();
+ } else {
+ // Set the minimum filter size that's feasible.
+ for (int i = 0; i < 4; ++i) {
+ input_shape.add_dim()->set_size(1);
+ }
+ *found_unknown_shapes = true;
}
- const auto& input_shape = op_features.outputs(0).shape();
ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
input_shape, op_features.inputs(1).shape(), op_features,
found_unknown_shapes);
@@ -758,18 +775,34 @@ int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations(
const OpInfo& op_features, ConvolutionDimensions* returned_conv_dims,
bool* found_unknown_shapes) const {
int64 ops = 0;
- if (op_features.op() != kConv2dBackpropFilter) {
- LOG(ERROR) << "Invalid Operation";
- return ops;
+ DCHECK_EQ(kConv2dBackpropFilter, op_features.op());
+
+ TensorShapeProto filter_shape;
+ if (op_features.inputs_size() >= 2 && op_features.inputs(1).has_value()) {
+ const TensorProto& value = op_features.inputs(1).value();
+ if (value.int64_val_size() > 0) {
+ for (int i = 0; i < value.int64_val_size(); ++i) {
+ filter_shape.add_dim()->set_size(value.int64_val(i));
+ }
+ } else {
+ for (int i = 0; i < value.int_val_size(); ++i) {
+ filter_shape.add_dim()->set_size(value.int_val(i));
+ }
+ }
+ } else if (op_features.outputs_size() == 1) {
+ filter_shape = op_features.outputs(0).shape();
+ } else {
+ // Set the minimum filter size that's feasible.
+ for (int i = 0; i < 4; ++i) {
+ filter_shape.add_dim()->set_size(1);
+ }
+ *found_unknown_shapes = true;
}
- if (op_features.outputs_size() != 1) {
- // Need _output_shapes for input shape.
- LOG(ERROR) << "No output shape in Conv2DBackpropFilter op.";
+ if (op_features.inputs_size() < 1) {
+ *found_unknown_shapes = true;
return ops;
}
-
- const auto& filter_shape = op_features.outputs(0).shape();
ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
op_features.inputs(0).shape(), filter_shape, op_features,
found_unknown_shapes);