diff options
Diffstat (limited to 'tensorflow/core/grappler/costs/graph_properties_test.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/graph_properties_test.cc | 98 |
1 files changed, 98 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 32683644fb..94b809dc44 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/grappler/clusters/single_machine.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" @@ -129,6 +132,101 @@ TEST_F(GraphPropertiesTest, DynamicProperties) { } } +TEST_F(GraphPropertiesTest, VarHandles) { + GrapplerItem item; + TF_CHECK_OK(NodeDefBuilder("Var", "VarHandleOp") + .Attr("dtype", DT_FLOAT) + .Attr("shape", TensorShape({3, 7})) + .Finalize(item.graph.add_node())); + + TF_CHECK_OK(NodeDefBuilder("VarRead", "ReadVariableOp") + .Attr("dtype", DT_FLOAT) + .Input("Var", 0, DT_RESOURCE) + .Finalize(item.graph.add_node())); + + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically()); + + const auto props = properties.GetOutputProperties("VarRead"); + EXPECT_EQ(1, props.size()); + const OpInfo::TensorProperties& prop = props[0]; + EXPECT_EQ(DT_FLOAT, prop.dtype()); + EXPECT_FALSE(prop.shape().unknown_rank()); + EXPECT_EQ(2, prop.shape().dim_size()); + EXPECT_EQ(3, prop.shape().dim(0).size()); + EXPECT_EQ(7, prop.shape().dim(1).size()); +} + +TEST_F(GraphPropertiesTest, Queues) { + // Create a graph with known input shapes, and propagate the shapes through a + // couple of queues. + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + + auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT}); + Output rnd = + ops::RandomNormal(root.WithOpName("rnd"), {3, 7}, DataType::DT_FLOAT); + Output square1 = ops::Square(root.WithOpName("Square1"), rnd); + auto enqueue1 = ops::QueueEnqueue(root.WithOpName("Enqueue1"), q1, {square1}); + auto dequeue1 = + ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT}); + + auto q2 = + ops::RandomShuffleQueue(root.WithOpName("Queue2"), {DataType::DT_FLOAT}); + Output square2 = ops::Square(root.WithOpName("Square2"), dequeue1[0]); + auto enqueue2 = ops::QueueEnqueue(root.WithOpName("Enqueue2"), q2, {square2}); + auto dequeue2 = + ops::QueueDequeue(root.WithOpName("Dequeue2"), q2, {DataType::DT_FLOAT}); + + auto q3 = + ops::RandomShuffleQueue(root.WithOpName("Queue3"), {DataType::DT_FLOAT}); + auto dequeue3 = + ops::QueueDequeue(root.WithOpName("Dequeue3"), q3, {DataType::DT_FLOAT}); + + auto q4 = + ops::RandomShuffleQueue(root.WithOpName("Queue4"), {DataType::DT_FLOAT}); + auto enqueue4 = ops::QueueEnqueue(root.WithOpName("Enqueue4"), q4, {square2}); + auto enqueue4_2 = + ops::QueueEnqueue(root.WithOpName("Enqueue4_2"), q4, {dequeue3[0]}); + auto dequeue4 = + ops::QueueDequeue(root.WithOpName("Dequeue4"), q4, {DataType::DT_FLOAT}); + + GrapplerItem item; + TF_CHECK_OK(root.ToGraphDef(&item.graph)); + + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically()); + + const auto props1 = properties.GetOutputProperties("Dequeue1"); + EXPECT_EQ(1, props1.size()); + const OpInfo::TensorProperties& prop1 = props1[0]; + EXPECT_EQ(DT_FLOAT, prop1.dtype()); + EXPECT_FALSE(prop1.shape().unknown_rank()); + EXPECT_EQ(2, prop1.shape().dim_size()); + EXPECT_EQ(3, prop1.shape().dim(0).size()); + EXPECT_EQ(7, prop1.shape().dim(1).size()); + + const auto props2 = properties.GetOutputProperties("Dequeue2"); + EXPECT_EQ(1, props2.size()); + const OpInfo::TensorProperties& prop2 = props2[0]; + EXPECT_EQ(DT_FLOAT, prop2.dtype()); + EXPECT_FALSE(prop2.shape().unknown_rank()); + EXPECT_EQ(2, prop2.shape().dim_size()); + EXPECT_EQ(3, prop2.shape().dim(0).size()); + EXPECT_EQ(7, prop2.shape().dim(1).size()); + + // The dequeue3 op shape is unknown. The square2 op shape is known. Verify + // that we merge the 2 properly to determine the shape of the data coming out + // of the queue. + const auto props4 = properties.GetOutputProperties("Dequeue4"); + EXPECT_EQ(1, props4.size()); + const OpInfo::TensorProperties& prop4 = props4[0]; + EXPECT_EQ(DT_FLOAT, prop4.dtype()); + EXPECT_FALSE(prop4.shape().unknown_rank()); + EXPECT_EQ(2, prop4.shape().dim_size()); + EXPECT_EQ(3, prop4.shape().dim(0).size()); + EXPECT_EQ(7, prop4.shape().dim(1).size()); +} + } // namespace } // namespace grappler } // namespace tensorflow |