diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-02-22 15:34:50 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-22 15:38:52 -0800 |
commit | 848c53fb11cab2631695cdb6c38bbdfeee972a75 (patch) | |
tree | 9c9ae9d599edaf3c72e8194c24d159ad30920653 | |
parent | 6006f46dd7531b112360b831aa61de6c46618166 (diff) |
Implement the logic to parse TensorProto (the tensor value for input or filter shape info) in op_level_cost_estimator.
PiperOrigin-RevId: 186685409
3 files changed, 172 insertions, 22 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index a57cfdd989..983b6891f1 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -718,6 +718,56 @@ int64 OpLevelCostEstimator::CountBatchMatMulOperations( return ops; } +bool GetTensorShapeProtoFromTensorProto(const TensorProto& tensor_proto, + TensorShapeProto* tensor_shape_proto) { + tensor_shape_proto->Clear(); + // First convert TensorProto into Tensor class so that it correctly parses + // data values within TensorProto (whether it's in int_val, int64_val, + // tensor_content, or anything. + Tensor tensor(tensor_proto.dtype()); + if (!tensor.FromProto(tensor_proto)) { + LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- " + << "failed to parse TensorProto: " + << tensor_proto.DebugString(); + return false; + } + if (tensor.dims() != 1) { + LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- " + << "tensor is not 1D: " << tensor.dims(); + return false; + } + // Then, convert it back to TensorProto using AsProtoField, which makes sure + // the data is in int_val, int64_val, or such repeated data fields, not in + // tensor_content. + TensorProto temp_tensor; + tensor.AsProtoField(&temp_tensor); + +#define TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(type) \ + do { \ + for (const auto& value : temp_tensor.type##_val()) { \ + tensor_shape_proto->add_dim()->set_size(value); \ + } \ + } while (0) + + if (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT16 || + tensor.dtype() == DT_INT8 || tensor.dtype() == DT_UINT8) { + TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(int); + } else if (tensor.dtype() == DT_INT64) { + TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(int64); + } else if (tensor.dtype() == DT_UINT32) { + TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(uint32); + } else if (tensor.dtype() == DT_UINT64) { + TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(uint64); + } else { + LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- " + << "Unsupported dtype: " << tensor.dtype(); + return false; + } +#undef TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO + + return true; +} + // TODO(cliffy): Dedup this method and CountConv2DBackpropFilterOperations. int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations( const OpInfo& op_features, ConvolutionDimensions* returned_conv_dims, @@ -732,20 +782,16 @@ int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations( } TensorShapeProto input_shape; + bool shape_found = false; if (op_features.inputs(0).has_value()) { const TensorProto& value = op_features.inputs(0).value(); - if (value.int64_val_size() > 0) { - for (int i = 0; i < value.int64_val_size(); ++i) { - input_shape.add_dim()->set_size(value.int64_val(i)); - } - } else { - for (int i = 0; i < value.int_val_size(); ++i) { - input_shape.add_dim()->set_size(value.int_val(i)); - } - } - } else if (op_features.outputs_size() == 1) { + shape_found = GetTensorShapeProtoFromTensorProto(value, &input_shape); + } + if (!shape_found && op_features.outputs_size() == 1) { input_shape = op_features.outputs(0).shape(); - } else { + shape_found = true; + } + if (!shape_found) { // Set the minimum filter size that's feasible. for (int i = 0; i < 4; ++i) { input_shape.add_dim()->set_size(1); @@ -778,20 +824,16 @@ int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations( DCHECK_EQ(kConv2dBackpropFilter, op_features.op()); TensorShapeProto filter_shape; + bool shape_found = false; if (op_features.inputs_size() >= 2 && op_features.inputs(1).has_value()) { const TensorProto& value = op_features.inputs(1).value(); - if (value.int64_val_size() > 0) { - for (int i = 0; i < value.int64_val_size(); ++i) { - filter_shape.add_dim()->set_size(value.int64_val(i)); - } - } else { - for (int i = 0; i < value.int_val_size(); ++i) { - filter_shape.add_dim()->set_size(value.int_val(i)); - } - } - } else if (op_features.outputs_size() == 1) { + shape_found = GetTensorShapeProtoFromTensorProto(value, &filter_shape); + } + if (!shape_found && op_features.outputs_size() == 1) { filter_shape = op_features.outputs(0).shape(); - } else { + shape_found = true; + } + if (!shape_found) { // Set the minimum filter size that's feasible. for (int i = 0; i < 4; ++i) { filter_shape.add_dim()->set_size(1); diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index a292e5e97f..7bb530fe31 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -28,6 +28,9 @@ limitations under the License. namespace tensorflow { namespace grappler { +bool GetTensorShapeProtoFromTensorProto(const TensorProto& tensor_proto, + TensorShapeProto* tensor_shape_proto); + class OpLevelCostEstimator { public: OpLevelCostEstimator(); diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc index 60fc783472..583d2619b2 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/test.h" @@ -247,5 +249,108 @@ TEST_F(OpLevelCostEstimatorTest, BatchMatMul) { EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate); } +// Helper functions for testing GetTensorShapeProtoFromTensorProto(). +void GetTensorProto(const DataType dtype, const std::vector<int64>& shape, + const std::vector<int64> values, const bool tensor_content, + TensorProto* tensor_proto) { + tensor_proto->Clear(); + TensorProto temp_tensor_proto; + temp_tensor_proto.set_dtype(dtype); + for (const auto& x : shape) { + temp_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(x); + } + for (const auto& x : values) { + if (dtype == DT_INT64) { + temp_tensor_proto.add_int64_val(x); + } else if (dtype == DT_INT32 || dtype == DT_INT16 || dtype == DT_INT8 || + dtype == DT_UINT8) { + temp_tensor_proto.add_int_val(x); + } else if (dtype == DT_UINT32) { + temp_tensor_proto.add_uint32_val(x); + } else if (dtype == DT_UINT64) { + temp_tensor_proto.add_uint64_val(x); + } else { + CHECK(false) << "Unsupported dtype: " << dtype; + } + } + Tensor tensor(dtype); + CHECK(tensor.FromProto(temp_tensor_proto)); + if (tensor_content) { + tensor.AsProtoTensorContent(tensor_proto); + } else { + tensor.AsProtoField(tensor_proto); + } +} + +void ExpectTensorShape(const std::vector<int64>& expected, + const TensorShapeProto& tensor_shape_proto) { + TensorShape tensor_shape_expected(expected); + TensorShape tensor_shape(tensor_shape_proto); + + LOG(INFO) << "Expected: " << tensor_shape_expected.DebugString(); + LOG(INFO) << "TensorShape: " << tensor_shape.DebugString(); + EXPECT_TRUE(tensor_shape_expected == tensor_shape); +} + +TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) { + TensorProto tensor_proto; + TensorShapeProto tensor_shape_proto; + + // Dimention larger than max value; should fail while converting to Tensor + // class. + tensor_proto.mutable_tensor_shape()->add_dim()->set_size(255); + EXPECT_FALSE( + GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto)); + + tensor_proto.Clear(); + // Expect only 1D shape. + tensor_proto.mutable_tensor_shape()->add_dim()->set_size(1); + tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2); + EXPECT_FALSE( + GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto)); + + // Expect only handle integer data types. + GetTensorProto(DT_FLOAT, {}, {}, /*tensor_content=*/false, &tensor_proto); + EXPECT_FALSE( + GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto)); + + // Check GetTensorShapeProtoFromTensorProto() resturns correct values. + { + std::vector<int64> shape_expected = {10, 20, 30, 40}; + GetTensorProto(DT_INT32, {4}, shape_expected, /*tensor_content=*/false, + &tensor_proto); + EXPECT_TRUE( + GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto)); + ExpectTensorShape(shape_expected, tensor_shape_proto); + } + + { + std::vector<int64> shape_expected = {40, 20, 90, 40}; + GetTensorProto(DT_INT64, {4}, shape_expected, /*tensor_content=*/false, + &tensor_proto); + EXPECT_TRUE( + GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto)); + ExpectTensorShape(shape_expected, tensor_shape_proto); + } + + { + std::vector<int64> shape_expected = {10, 20, 30, 40}; + GetTensorProto(DT_INT32, {4}, shape_expected, /*tensor_content=*/true, + &tensor_proto); + EXPECT_TRUE( + GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto)); + ExpectTensorShape(shape_expected, tensor_shape_proto); + } + + { + std::vector<int64> shape_expected = {40, 20, 90, 40}; + GetTensorProto(DT_INT64, {4}, shape_expected, /*tensor_content=*/true, + &tensor_proto); + EXPECT_TRUE( + GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto)); + ExpectTensorShape(shape_expected, tensor_shape_proto); + } +} + } // end namespace grappler } // end namespace tensorflow |