aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2017-03-27 11:23:53 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-27 12:48:31 -0700
commit94f610c1f2861744d62820153dd10839458978e3 (patch)
tree1e03d2cdfc57adb6e8f87b7c6c8ecbb45bc9f862
parentca6b88eb95b089010a1b970e0de7398195b5bcca (diff)
Feed the test input with random numbers instead of constants to prevent
constant folding from wiping out a large part of the test network Change: 151354452
-rw-r--r--tensorflow/core/grappler/costs/graph_memory_test.cc12
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_test.cc15
-rw-r--r--tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc3
3 files changed, 19 insertions, 11 deletions
diff --git a/tensorflow/core/grappler/costs/graph_memory_test.cc b/tensorflow/core/grappler/costs/graph_memory_test.cc
index a3c58a1d76..82c86064c6 100644
--- a/tensorflow/core/grappler/costs/graph_memory_test.cc
+++ b/tensorflow/core/grappler/costs/graph_memory_test.cc
@@ -32,7 +32,11 @@ TEST_F(GraphMemoryTest, Basic) {
GraphMemory memory(item);
Status s = memory.InferStatically();
TF_CHECK_OK(s);
- EXPECT_EQ(240, memory.GetWorstCaseMemoryUsage());
+ // 5 AddN + 1 random op each generating 10 values -> 240 bytes
+ // 4 more bytes for the mean of the distribution and 4 more for the stddev.
+ EXPECT_EQ(248, memory.GetWorstCaseMemoryUsage());
+ // If at most one op executes at a time, it needs 10 inputs values and 10
+ // output values, or 8 bytes.
EXPECT_EQ(80, memory.GetBestCaseMemoryUsage());
}
@@ -44,8 +48,10 @@ TEST_F(GraphMemoryTest, UnknownBatchSize) {
GraphMemory memory(item);
Status s = memory.InferStatically();
TF_CHECK_OK(s);
- EXPECT_EQ(24, memory.GetWorstCaseMemoryUsage());
- EXPECT_EQ(8, memory.GetBestCaseMemoryUsage());
+ // Same maths as before, except that batch size is unknown and therefore
+ // assumed to be one.
+ EXPECT_EQ(32, memory.GetWorstCaseMemoryUsage());
+ EXPECT_EQ(12, memory.GetBestCaseMemoryUsage());
}
} // namespace
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index d2f448e6d3..32683644fb 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -49,9 +49,9 @@ TEST_F(GraphPropertiesTest, StaticProperties) {
TF_CHECK_OK(s);
for (const auto& node : item.graph.node()) {
- if (node.op() == "Const") {
- // The const node has no input.
- EXPECT_EQ(0, properties.GetInputProperties(node.name()).size());
+ if (node.op() == "RandomStandardNormal") {
+ // The node has one input (the shape of the tensor to generate).
+ EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
// The const node has one output.
const auto props = properties.GetOutputProperties(node.name());
EXPECT_EQ(1, props.size());
@@ -94,12 +94,13 @@ TEST_F(GraphPropertiesTest, DynamicProperties) {
TF_CHECK_OK(s);
for (const auto& node : item.graph.node()) {
- if (node.op() == "Const") {
- // The constant node is missing from the cost graph
+ if (node.op() == "RandomStandardNormal") {
+ // The random node is missing from the cost graph (why ?)
EXPECT_EQ(0, properties.GetInputProperties(node.name()).size());
} else if (node.op() == "AddN") {
- // Since the const node is missing, we can't infer the input properties of
- // the first AddN node. THe other AddN have the expected properties
+ // Since the random node is missing, we can't infer the input properties
+ // of the first AddN node. The other AddN nodes have the expected
+ // properties.
if (node.name() == "AddN") {
const auto props = properties.GetInputProperties(node.name());
EXPECT_EQ(1, props.size());
diff --git a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc
index 8370133fc4..446ae2df64 100644
--- a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc
+++ b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc
@@ -39,7 +39,8 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
// x is from the feed.
const int batch_size = tensor_size < 0 ? 1 : tensor_size;
- Output x = Const(s.WithOpName("x"), 0.0f, {batch_size, 1});
+ Output x =
+ RandomNormal(s.WithOpName("x"), {batch_size, 1}, DataType::DT_FLOAT);
// Create stages.
std::vector<Output> last_stage;