diff options
Diffstat (limited to 'tensorflow/core/grappler/costs/utils_test.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/utils_test.cc | 112 |
1 files changed, 83 insertions, 29 deletions
diff --git a/tensorflow/core/grappler/costs/utils_test.cc b/tensorflow/core/grappler/costs/utils_test.cc index baa654f475..db5c11f0fe 100644 --- a/tensorflow/core/grappler/costs/utils_test.cc +++ b/tensorflow/core/grappler/costs/utils_test.cc @@ -26,36 +26,42 @@ limitations under the License. namespace tensorflow { namespace grappler { -class UtilsTest : public ::testing::Test { - public: - void CreateConstOp(const string& name, std::initializer_list<int64> dims, - NodeDef* node) { - Tensor tensor(DT_FLOAT, TensorShape(dims)); - for (int64 i = 0; i < tensor.NumElements(); ++i) { - tensor.flat<float>()(i) = i / 10.0f; - } - TF_CHECK_OK(NodeDefBuilder(name, "Const") - .Attr("dtype", DT_FLOAT) - .Attr("value", tensor) - .Finalize(node)); - } +namespace { - void CreateConstSizesOp(const string& name, const std::vector<int32>& sizes, - NodeDef* node) { - TensorShape shape; - shape.AddDim(sizes.size()); - Tensor tensor(DT_INT32, shape); - for (int64 i = 0; i < tensor.NumElements(); ++i) { - tensor.flat<int32>()(i) = sizes[i]; - } - TF_CHECK_OK(NodeDefBuilder(name, "Const") - .Attr("dtype", DT_INT32) - .Attr("value", tensor) - .Finalize(node)); - } -}; +void CreateConstOp(const string& name, std::initializer_list<int64> dims, + NodeDef* node) { + Tensor tensor(DT_FLOAT, TensorShape(dims)); + for (int64 i = 0; i < tensor.NumElements(); ++i) + tensor.flat<float>()(i) = i / 10.0f; + TF_CHECK_OK(NodeDefBuilder(name, "Const") + .Attr("dtype", DT_FLOAT) + .Attr("value", tensor) + .Finalize(node)); +} -TEST_F(UtilsTest, ConvOpInfo) { +void CreateConstSizesOp(const string& name, const std::vector<int32>& sizes, + NodeDef* node) { + TensorShape shape; + shape.AddDim(sizes.size()); + Tensor tensor(DT_INT32, shape); + for (int64 i = 0; i < tensor.NumElements(); ++i) + tensor.flat<int32>()(i) = sizes[i]; + TF_CHECK_OK(NodeDefBuilder(name, "Const") + .Attr("dtype", DT_INT32) + .Attr("value", tensor) + .Finalize(node)); +} + +// Helper method for converting shapes vector to TensorProperty. +OpInfo::TensorProperties ShapeToTensorProperty(const std::vector<int>& shapes, + const DataType& data_type) { + OpInfo::TensorProperties prop; + prop.set_dtype(data_type); + for (int shape : shapes) prop.mutable_shape()->add_dim()->set_size(shape); + return prop; +} + +TEST(UtilsTest, ConvOpInfo) { int batch = 32; int rows = 7; int cols = 9; @@ -146,7 +152,7 @@ TEST_F(UtilsTest, ConvOpInfo) { } } -TEST_F(UtilsTest, TestSkipControlInput) { +TEST(UtilsTest, TestSkipControlInput) { GraphDef graph; TF_CHECK_OK(NodeDefBuilder("constant", "Const") .Attr("dtype", DT_INT32) @@ -172,6 +178,52 @@ TEST_F(UtilsTest, TestSkipControlInput) { EXPECT_TRUE(node_found); } +TEST(UtilsTest, CalculateTensorSize) { + // Test normal usage. + EXPECT_EQ(DataTypeSize(DT_FLOAT) * 1, + CalculateTensorSize(ShapeToTensorProperty({1}, DT_FLOAT))); + EXPECT_EQ(DataTypeSize(DT_FLOAT) * 4 * 4, + CalculateTensorSize(ShapeToTensorProperty({4, 4}, DT_FLOAT))); + EXPECT_EQ(DataTypeSize(DT_HALF) * 10 * 10 * 10, + CalculateTensorSize(ShapeToTensorProperty({10, 10, 10}, DT_HALF))); + EXPECT_EQ( + DataTypeSize(DT_FLOAT) * 100 * 7 * 8 * 99, + CalculateTensorSize(ShapeToTensorProperty({100, 7, 8, 99}, DT_FLOAT))); + + // Test unknown rank: assumes the tensor to be a scalar. + OpInfo::TensorProperties t = ShapeToTensorProperty({100, 7, 8, 99}, DT_FLOAT); + t.mutable_shape()->set_unknown_rank(true); + EXPECT_EQ(DataTypeSize(DT_FLOAT) * 1, CalculateTensorSize(t)); + + // Test unknown shape: assumes unknown shape (-1) to have size 1. + EXPECT_EQ( + DataTypeSize(DT_FLOAT) * 1 * 7 * 8 * 99, + CalculateTensorSize(ShapeToTensorProperty({-1, 7, 8, 99}, DT_FLOAT))); + EXPECT_EQ( + DataTypeSize(DT_FLOAT) * 1 * 7 * 1 * 99, + CalculateTensorSize(ShapeToTensorProperty({-1, 7, -1, 99}, DT_FLOAT))); +} + +TEST(UtilsTest, CalculateOutputSize) { + // Create a set of tensor properties. + std::vector<OpInfo::TensorProperties> output = { + ShapeToTensorProperty({4, 4}, DT_FLOAT), // 0 + ShapeToTensorProperty({-1, 7, -1, 99}, DT_FLOAT) // 1 + }; + + // Test valid outputs. + EXPECT_EQ(DataTypeSize(DT_FLOAT) * 4 * 4, CalculateOutputSize(output, 0)); + EXPECT_EQ(DataTypeSize(DT_FLOAT) * 1 * 7 * 1 * 99, + CalculateOutputSize(output, 1)); + + // port_num -1 is for control dependency: hard coded 4B. + EXPECT_EQ(4, CalculateOutputSize(output, -1)); + + // Invalid port_num (though it may be an error) shall yield zero + // output size. + EXPECT_EQ(0, CalculateOutputSize(output, 2)); +} + // Class for testing TensorSizeHistogram. class TestTensorSizeHistogram : public TensorSizeHistogram { public: @@ -285,5 +337,7 @@ TEST(DeviceClassTest, GetDeviceClassForNonChannelDevice) { EXPECT_EQ("//GPU", GetDeviceClassForNonChannelDevice("/device:GPU:7")); } +} // namespace + } // end namespace grappler } // end namespace tensorflow |