aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/utils_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/costs/utils_test.cc')
-rw-r--r--tensorflow/core/grappler/costs/utils_test.cc112
1 files changed, 83 insertions, 29 deletions
diff --git a/tensorflow/core/grappler/costs/utils_test.cc b/tensorflow/core/grappler/costs/utils_test.cc
index baa654f475..db5c11f0fe 100644
--- a/tensorflow/core/grappler/costs/utils_test.cc
+++ b/tensorflow/core/grappler/costs/utils_test.cc
@@ -26,36 +26,42 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-class UtilsTest : public ::testing::Test {
- public:
- void CreateConstOp(const string& name, std::initializer_list<int64> dims,
- NodeDef* node) {
- Tensor tensor(DT_FLOAT, TensorShape(dims));
- for (int64 i = 0; i < tensor.NumElements(); ++i) {
- tensor.flat<float>()(i) = i / 10.0f;
- }
- TF_CHECK_OK(NodeDefBuilder(name, "Const")
- .Attr("dtype", DT_FLOAT)
- .Attr("value", tensor)
- .Finalize(node));
- }
+namespace {
- void CreateConstSizesOp(const string& name, const std::vector<int32>& sizes,
- NodeDef* node) {
- TensorShape shape;
- shape.AddDim(sizes.size());
- Tensor tensor(DT_INT32, shape);
- for (int64 i = 0; i < tensor.NumElements(); ++i) {
- tensor.flat<int32>()(i) = sizes[i];
- }
- TF_CHECK_OK(NodeDefBuilder(name, "Const")
- .Attr("dtype", DT_INT32)
- .Attr("value", tensor)
- .Finalize(node));
- }
-};
+void CreateConstOp(const string& name, std::initializer_list<int64> dims,
+ NodeDef* node) {
+ Tensor tensor(DT_FLOAT, TensorShape(dims));
+ for (int64 i = 0; i < tensor.NumElements(); ++i)
+ tensor.flat<float>()(i) = i / 10.0f;
+ TF_CHECK_OK(NodeDefBuilder(name, "Const")
+ .Attr("dtype", DT_FLOAT)
+ .Attr("value", tensor)
+ .Finalize(node));
+}
-TEST_F(UtilsTest, ConvOpInfo) {
+void CreateConstSizesOp(const string& name, const std::vector<int32>& sizes,
+ NodeDef* node) {
+ TensorShape shape;
+ shape.AddDim(sizes.size());
+ Tensor tensor(DT_INT32, shape);
+ for (int64 i = 0; i < tensor.NumElements(); ++i)
+ tensor.flat<int32>()(i) = sizes[i];
+ TF_CHECK_OK(NodeDefBuilder(name, "Const")
+ .Attr("dtype", DT_INT32)
+ .Attr("value", tensor)
+ .Finalize(node));
+}
+
+// Helper method for converting shapes vector to TensorProperty.
+OpInfo::TensorProperties ShapeToTensorProperty(const std::vector<int>& shapes,
+ const DataType& data_type) {
+ OpInfo::TensorProperties prop;
+ prop.set_dtype(data_type);
+ for (int shape : shapes) prop.mutable_shape()->add_dim()->set_size(shape);
+ return prop;
+}
+
+TEST(UtilsTest, ConvOpInfo) {
int batch = 32;
int rows = 7;
int cols = 9;
@@ -146,7 +152,7 @@ TEST_F(UtilsTest, ConvOpInfo) {
}
}
-TEST_F(UtilsTest, TestSkipControlInput) {
+TEST(UtilsTest, TestSkipControlInput) {
GraphDef graph;
TF_CHECK_OK(NodeDefBuilder("constant", "Const")
.Attr("dtype", DT_INT32)
@@ -172,6 +178,52 @@ TEST_F(UtilsTest, TestSkipControlInput) {
EXPECT_TRUE(node_found);
}
+TEST(UtilsTest, CalculateTensorSize) {
+ // Test normal usage.
+ EXPECT_EQ(DataTypeSize(DT_FLOAT) * 1,
+ CalculateTensorSize(ShapeToTensorProperty({1}, DT_FLOAT)));
+ EXPECT_EQ(DataTypeSize(DT_FLOAT) * 4 * 4,
+ CalculateTensorSize(ShapeToTensorProperty({4, 4}, DT_FLOAT)));
+ EXPECT_EQ(DataTypeSize(DT_HALF) * 10 * 10 * 10,
+ CalculateTensorSize(ShapeToTensorProperty({10, 10, 10}, DT_HALF)));
+ EXPECT_EQ(
+ DataTypeSize(DT_FLOAT) * 100 * 7 * 8 * 99,
+ CalculateTensorSize(ShapeToTensorProperty({100, 7, 8, 99}, DT_FLOAT)));
+
+ // Test unknown rank: assumes the tensor to be a scalar.
+ OpInfo::TensorProperties t = ShapeToTensorProperty({100, 7, 8, 99}, DT_FLOAT);
+ t.mutable_shape()->set_unknown_rank(true);
+ EXPECT_EQ(DataTypeSize(DT_FLOAT) * 1, CalculateTensorSize(t));
+
+ // Test unknown shape: assumes unknown shape (-1) to have size 1.
+ EXPECT_EQ(
+ DataTypeSize(DT_FLOAT) * 1 * 7 * 8 * 99,
+ CalculateTensorSize(ShapeToTensorProperty({-1, 7, 8, 99}, DT_FLOAT)));
+ EXPECT_EQ(
+ DataTypeSize(DT_FLOAT) * 1 * 7 * 1 * 99,
+ CalculateTensorSize(ShapeToTensorProperty({-1, 7, -1, 99}, DT_FLOAT)));
+}
+
+TEST(UtilsTest, CalculateOutputSize) {
+ // Create a set of tensor properties.
+ std::vector<OpInfo::TensorProperties> output = {
+ ShapeToTensorProperty({4, 4}, DT_FLOAT), // 0
+ ShapeToTensorProperty({-1, 7, -1, 99}, DT_FLOAT) // 1
+ };
+
+ // Test valid outputs.
+ EXPECT_EQ(DataTypeSize(DT_FLOAT) * 4 * 4, CalculateOutputSize(output, 0));
+ EXPECT_EQ(DataTypeSize(DT_FLOAT) * 1 * 7 * 1 * 99,
+ CalculateOutputSize(output, 1));
+
+ // port_num -1 is for control dependency: hard coded 4B.
+ EXPECT_EQ(4, CalculateOutputSize(output, -1));
+
+ // Invalid port_num (though it may be an error) shall yield zero
+ // output size.
+ EXPECT_EQ(0, CalculateOutputSize(output, 2));
+}
+
// Class for testing TensorSizeHistogram.
class TestTensorSizeHistogram : public TensorSizeHistogram {
public:
@@ -285,5 +337,7 @@ TEST(DeviceClassTest, GetDeviceClassForNonChannelDevice) {
EXPECT_EQ("//GPU", GetDeviceClassForNonChannelDevice("/device:GPU:7"));
}
+} // namespace
+
} // end namespace grappler
} // end namespace tensorflow