author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-02-22 15:34:50 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-02-22 15:38:52 -0800
commit    848c53fb11cab2631695cdb6c38bbdfeee972a75 (patch)
tree      9c9ae9d599edaf3c72e8194c24d159ad30920653
parent    6006f46dd7531b112360b831aa61de6c46618166 (diff)
Implement the logic to parse TensorProto (the tensor value for input or filter shape info) in op_level_cost_estimator.
PiperOrigin-RevId: 186685409
-rw-r--r--  tensorflow/core/grappler/costs/op_level_cost_estimator.cc        86
-rw-r--r--  tensorflow/core/grappler/costs/op_level_cost_estimator.h          3
-rw-r--r--  tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc  105
3 files changed, 172 insertions, 22 deletions
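
For context, a minimal sketch of how the new helper is meant to be called from the cost estimator. The RecoverShapeOrFallback wrapper and its name are hypothetical, added only to illustrate the fallback behavior that the updated Conv2D backprop counters implement below; GetTensorShapeProtoFromTensorProto itself is the function introduced by this change.

#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"

namespace tensorflow {
namespace grappler {

// Declared in op_level_cost_estimator.h by this change.
bool GetTensorShapeProtoFromTensorProto(const TensorProto& tensor_proto,
                                        TensorShapeProto* tensor_shape_proto);

// Hypothetical caller: recover a 4-D shape from a shape-carrying constant
// input (e.g. the first input of Conv2DBackpropInput), falling back to the
// minimum feasible 1x1x1x1 shape when parsing fails.
TensorShapeProto RecoverShapeOrFallback(const TensorProto& shape_value) {
  TensorShapeProto shape;
  if (!GetTensorShapeProtoFromTensorProto(shape_value, &shape)) {
    shape.Clear();
    for (int i = 0; i < 4; ++i) shape.add_dim()->set_size(1);
  }
  return shape;
}

}  // namespace grappler
}  // namespace tensorflow
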
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index a57cfdd989..983b6891f1 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -718,6 +718,56 @@ int64 OpLevelCostEstimator::CountBatchMatMulOperations(
return ops;
}
+bool GetTensorShapeProtoFromTensorProto(const TensorProto& tensor_proto,
+ TensorShapeProto* tensor_shape_proto) {
+ tensor_shape_proto->Clear();
+  // First, convert the TensorProto into a Tensor so that its data values are
+  // parsed correctly regardless of where they are stored (int_val, int64_val,
+  // tensor_content, etc.).
+ Tensor tensor(tensor_proto.dtype());
+ if (!tensor.FromProto(tensor_proto)) {
+ LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
+ << "failed to parse TensorProto: "
+ << tensor_proto.DebugString();
+ return false;
+ }
+ if (tensor.dims() != 1) {
+ LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
+ << "tensor is not 1D: " << tensor.dims();
+ return false;
+ }
+  // Then, convert it back to a TensorProto using AsProtoField, which ensures
+  // the data ends up in the typed repeated fields (int_val, int64_val, etc.)
+  // rather than in tensor_content.
+ TensorProto temp_tensor;
+ tensor.AsProtoField(&temp_tensor);
+
+#define TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(type) \
+ do { \
+ for (const auto& value : temp_tensor.type##_val()) { \
+ tensor_shape_proto->add_dim()->set_size(value); \
+ } \
+ } while (0)
+
+ if (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT16 ||
+ tensor.dtype() == DT_INT8 || tensor.dtype() == DT_UINT8) {
+ TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(int);
+ } else if (tensor.dtype() == DT_INT64) {
+ TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(int64);
+ } else if (tensor.dtype() == DT_UINT32) {
+ TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(uint32);
+ } else if (tensor.dtype() == DT_UINT64) {
+ TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(uint64);
+ } else {
+ LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
+ << "Unsupported dtype: " << tensor.dtype();
+ return false;
+ }
+#undef TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO
+
+ return true;
+}
+
// TODO(cliffy): Dedup this method and CountConv2DBackpropFilterOperations.
int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations(
const OpInfo& op_features, ConvolutionDimensions* returned_conv_dims,
@@ -732,20 +782,16 @@ int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations(
}
TensorShapeProto input_shape;
+ bool shape_found = false;
if (op_features.inputs(0).has_value()) {
const TensorProto& value = op_features.inputs(0).value();
- if (value.int64_val_size() > 0) {
- for (int i = 0; i < value.int64_val_size(); ++i) {
- input_shape.add_dim()->set_size(value.int64_val(i));
- }
- } else {
- for (int i = 0; i < value.int_val_size(); ++i) {
- input_shape.add_dim()->set_size(value.int_val(i));
- }
- }
- } else if (op_features.outputs_size() == 1) {
+ shape_found = GetTensorShapeProtoFromTensorProto(value, &input_shape);
+ }
+ if (!shape_found && op_features.outputs_size() == 1) {
input_shape = op_features.outputs(0).shape();
- } else {
+ shape_found = true;
+ }
+ if (!shape_found) {
    // Set the minimum input size that's feasible.
for (int i = 0; i < 4; ++i) {
input_shape.add_dim()->set_size(1);
@@ -778,20 +824,16 @@ int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations(
DCHECK_EQ(kConv2dBackpropFilter, op_features.op());
TensorShapeProto filter_shape;
+ bool shape_found = false;
if (op_features.inputs_size() >= 2 && op_features.inputs(1).has_value()) {
const TensorProto& value = op_features.inputs(1).value();
- if (value.int64_val_size() > 0) {
- for (int i = 0; i < value.int64_val_size(); ++i) {
- filter_shape.add_dim()->set_size(value.int64_val(i));
- }
- } else {
- for (int i = 0; i < value.int_val_size(); ++i) {
- filter_shape.add_dim()->set_size(value.int_val(i));
- }
- }
- } else if (op_features.outputs_size() == 1) {
+ shape_found = GetTensorShapeProtoFromTensorProto(value, &filter_shape);
+ }
+ if (!shape_found && op_features.outputs_size() == 1) {
filter_shape = op_features.outputs(0).shape();
- } else {
+ shape_found = true;
+ }
+ if (!shape_found) {
// Set the minimum filter size that's feasible.
for (int i = 0; i < 4; ++i) {
filter_shape.add_dim()->set_size(1);
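
Aside: a standalone sketch (assuming the TensorFlow C++ Tensor API used above) of the conversion chain the new helper relies on, i.e. TensorProto -> Tensor::FromProto() -> Tensor::AsProtoField(), after which values are guaranteed to sit in the typed repeated fields rather than in tensor_content. The DimsFromInt32Proto name is illustrative only.

#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"

namespace {

// Returns the dimension sizes carried by a 1-D DT_INT32 TensorProto, or an
// empty vector if the proto cannot be parsed as such.
std::vector<int> DimsFromInt32Proto(const tensorflow::TensorProto& proto) {
  tensorflow::Tensor tensor(proto.dtype());
  if (!tensor.FromProto(proto) || tensor.dims() != 1) return {};
  tensorflow::TensorProto normalized;
  tensor.AsProtoField(&normalized);  // Values are now in int_val, not tensor_content.
  std::vector<int> dims(normalized.int_val().begin(), normalized.int_val().end());
  return dims;
}

}  // namespace
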
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index a292e5e97f..7bb530fe31 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -28,6 +28,9 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
+bool GetTensorShapeProtoFromTensorProto(const TensorProto& tensor_proto,
+ TensorShapeProto* tensor_shape_proto);
+
class OpLevelCostEstimator {
public:
OpLevelCostEstimator();
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index 60fc783472..583d2619b2 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/test.h"
@@ -247,5 +249,108 @@ TEST_F(OpLevelCostEstimatorTest, BatchMatMul) {
EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate);
}
+// Helper functions for testing GetTensorShapeProtoFromTensorProto().
+void GetTensorProto(const DataType dtype, const std::vector<int64>& shape,
+                    const std::vector<int64>& values, const bool tensor_content,
+ TensorProto* tensor_proto) {
+ tensor_proto->Clear();
+ TensorProto temp_tensor_proto;
+ temp_tensor_proto.set_dtype(dtype);
+ for (const auto& x : shape) {
+ temp_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(x);
+ }
+ for (const auto& x : values) {
+ if (dtype == DT_INT64) {
+ temp_tensor_proto.add_int64_val(x);
+ } else if (dtype == DT_INT32 || dtype == DT_INT16 || dtype == DT_INT8 ||
+ dtype == DT_UINT8) {
+ temp_tensor_proto.add_int_val(x);
+ } else if (dtype == DT_UINT32) {
+ temp_tensor_proto.add_uint32_val(x);
+ } else if (dtype == DT_UINT64) {
+ temp_tensor_proto.add_uint64_val(x);
+ } else {
+ CHECK(false) << "Unsupported dtype: " << dtype;
+ }
+ }
+ Tensor tensor(dtype);
+ CHECK(tensor.FromProto(temp_tensor_proto));
+ if (tensor_content) {
+ tensor.AsProtoTensorContent(tensor_proto);
+ } else {
+ tensor.AsProtoField(tensor_proto);
+ }
+}
+
+void ExpectTensorShape(const std::vector<int64>& expected,
+ const TensorShapeProto& tensor_shape_proto) {
+ TensorShape tensor_shape_expected(expected);
+ TensorShape tensor_shape(tensor_shape_proto);
+
+ LOG(INFO) << "Expected: " << tensor_shape_expected.DebugString();
+ LOG(INFO) << "TensorShape: " << tensor_shape.DebugString();
+ EXPECT_TRUE(tensor_shape_expected == tensor_shape);
+}
+
+TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
+ TensorProto tensor_proto;
+ TensorShapeProto tensor_shape_proto;
+
+  // Dimension larger than max value; should fail while converting to Tensor
+ // class.
+ tensor_proto.mutable_tensor_shape()->add_dim()->set_size(255);
+ EXPECT_FALSE(
+ GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
+
+ tensor_proto.Clear();
+ // Expect only 1D shape.
+ tensor_proto.mutable_tensor_shape()->add_dim()->set_size(1);
+ tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);
+ EXPECT_FALSE(
+ GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
+
+  // Expect only integer data types to be handled.
+ GetTensorProto(DT_FLOAT, {}, {}, /*tensor_content=*/false, &tensor_proto);
+ EXPECT_FALSE(
+ GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
+
+  // Check that GetTensorShapeProtoFromTensorProto() returns correct values.
+ {
+ std::vector<int64> shape_expected = {10, 20, 30, 40};
+ GetTensorProto(DT_INT32, {4}, shape_expected, /*tensor_content=*/false,
+ &tensor_proto);
+ EXPECT_TRUE(
+ GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
+ ExpectTensorShape(shape_expected, tensor_shape_proto);
+ }
+
+ {
+ std::vector<int64> shape_expected = {40, 20, 90, 40};
+ GetTensorProto(DT_INT64, {4}, shape_expected, /*tensor_content=*/false,
+ &tensor_proto);
+ EXPECT_TRUE(
+ GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
+ ExpectTensorShape(shape_expected, tensor_shape_proto);
+ }
+
+ {
+ std::vector<int64> shape_expected = {10, 20, 30, 40};
+ GetTensorProto(DT_INT32, {4}, shape_expected, /*tensor_content=*/true,
+ &tensor_proto);
+ EXPECT_TRUE(
+ GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
+ ExpectTensorShape(shape_expected, tensor_shape_proto);
+ }
+
+ {
+ std::vector<int64> shape_expected = {40, 20, 90, 40};
+ GetTensorProto(DT_INT64, {4}, shape_expected, /*tensor_content=*/true,
+ &tensor_proto);
+ EXPECT_TRUE(
+ GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
+ ExpectTensorShape(shape_expected, tensor_shape_proto);
+ }
+}
+
} // end namespace grappler
} // end namespace tensorflow
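
Aside: the tensor_content flag in the GetTensorProto() test helper above toggles between the two TensorProto encodings the tests exercise. A small illustrative sketch (not part of the change) of the difference:

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"

void ShowBothEncodings() {
  tensorflow::Tensor t(tensorflow::DT_INT32, tensorflow::TensorShape({4}));
  for (int i = 0; i < 4; ++i) t.flat<tensorflow::int32>()(i) = 10 * (i + 1);

  tensorflow::TensorProto as_field, as_content;
  t.AsProtoField(&as_field);            // Values stored in as_field.int_val().
  t.AsProtoTensorContent(&as_content);  // Raw bytes stored in as_content.tensor_content().
}
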