diff options
author | Benoit Steiner <bsteiner@google.com> | 2017-03-27 11:23:53 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-27 12:48:31 -0700 |
commit | 94f610c1f2861744d62820153dd10839458978e3 (patch) | |
tree | 1e03d2cdfc57adb6e8f87b7c6c8ecbb45bc9f862 | |
parent | ca6b88eb95b089010a1b970e0de7398195b5bcca (diff) |
Feed the test input with random numbers instead of constants to prevent
constant folding from wiping out a large part of the test network
Change: 151354452
3 files changed, 19 insertions, 11 deletions
diff --git a/tensorflow/core/grappler/costs/graph_memory_test.cc b/tensorflow/core/grappler/costs/graph_memory_test.cc index a3c58a1d76..82c86064c6 100644 --- a/tensorflow/core/grappler/costs/graph_memory_test.cc +++ b/tensorflow/core/grappler/costs/graph_memory_test.cc @@ -32,7 +32,11 @@ TEST_F(GraphMemoryTest, Basic) { GraphMemory memory(item); Status s = memory.InferStatically(); TF_CHECK_OK(s); - EXPECT_EQ(240, memory.GetWorstCaseMemoryUsage()); + // 5 AddN + 1 random op each generating 10 values -> 240 bytes + // 4 more bytes for the mean of the distribution and 4 more for the stddev. + EXPECT_EQ(248, memory.GetWorstCaseMemoryUsage()); + // If at most one op executes at a time, it needs 10 input values and 10 + // output values, or 80 bytes. EXPECT_EQ(80, memory.GetBestCaseMemoryUsage()); } @@ -44,8 +48,10 @@ TEST_F(GraphMemoryTest, UnknownBatchSize) { GraphMemory memory(item); Status s = memory.InferStatically(); TF_CHECK_OK(s); - EXPECT_EQ(24, memory.GetWorstCaseMemoryUsage()); - EXPECT_EQ(8, memory.GetBestCaseMemoryUsage()); + // Same maths as before, except that batch size is unknown and therefore + // assumed to be one. + EXPECT_EQ(32, memory.GetWorstCaseMemoryUsage()); + EXPECT_EQ(12, memory.GetBestCaseMemoryUsage()); } } // namespace diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index d2f448e6d3..32683644fb 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -49,9 +49,9 @@ TEST_F(GraphPropertiesTest, StaticProperties) { TF_CHECK_OK(s); for (const auto& node : item.graph.node()) { - if (node.op() == "Const") { - // The const node has no input. - EXPECT_EQ(0, properties.GetInputProperties(node.name()).size()); + if (node.op() == "RandomStandardNormal") { + // The node has one input (the shape of the tensor to generate). 
+ EXPECT_EQ(1, properties.GetInputProperties(node.name()).size()); // The const node has one output. const auto props = properties.GetOutputProperties(node.name()); EXPECT_EQ(1, props.size()); @@ -94,12 +94,13 @@ TEST_F(GraphPropertiesTest, DynamicProperties) { TF_CHECK_OK(s); for (const auto& node : item.graph.node()) { - if (node.op() == "Const") { - // The constant node is missing from the cost graph + if (node.op() == "RandomStandardNormal") { + // The random node is missing from the cost graph (why ?) EXPECT_EQ(0, properties.GetInputProperties(node.name()).size()); } else if (node.op() == "AddN") { - // Since the const node is missing, we can't infer the input properties of - // the first AddN node. THe other AddN have the expected properties + // Since the random node is missing, we can't infer the input properties + // of the first AddN node. The other AddN nodes have the expected + // properties. if (node.name() == "AddN") { const auto props = properties.GetInputProperties(node.name()); EXPECT_EQ(1, props.size()); diff --git a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc index 8370133fc4..446ae2df64 100644 --- a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc +++ b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc @@ -39,7 +39,8 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size, // x is from the feed. const int batch_size = tensor_size < 0 ? 1 : tensor_size; - Output x = Const(s.WithOpName("x"), 0.0f, {batch_size, 1}); + Output x = + RandomNormal(s.WithOpName("x"), {batch_size, 1}, DataType::DT_FLOAT); // Create stages. std::vector<Output> last_stage; |