aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/graph_properties_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/costs/graph_properties_test.cc')
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_test.cc98
1 files changed, 98 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 32683644fb..94b809dc44 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -14,6 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
@@ -129,6 +132,101 @@ TEST_F(GraphPropertiesTest, DynamicProperties) {
}
}
+TEST_F(GraphPropertiesTest, VarHandles) {
+ GrapplerItem item;
+ TF_CHECK_OK(NodeDefBuilder("Var", "VarHandleOp")
+ .Attr("dtype", DT_FLOAT)
+ .Attr("shape", TensorShape({3, 7}))
+ .Finalize(item.graph.add_node()));
+
+ TF_CHECK_OK(NodeDefBuilder("VarRead", "ReadVariableOp")
+ .Attr("dtype", DT_FLOAT)
+ .Input("Var", 0, DT_RESOURCE)
+ .Finalize(item.graph.add_node()));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically());
+
+ const auto props = properties.GetOutputProperties("VarRead");
+ EXPECT_EQ(1, props.size());
+ const OpInfo::TensorProperties& prop = props[0];
+ EXPECT_EQ(DT_FLOAT, prop.dtype());
+ EXPECT_FALSE(prop.shape().unknown_rank());
+ EXPECT_EQ(2, prop.shape().dim_size());
+ EXPECT_EQ(3, prop.shape().dim(0).size());
+ EXPECT_EQ(7, prop.shape().dim(1).size());
+}
+
+TEST_F(GraphPropertiesTest, Queues) {
+ // Create a graph with known input shapes, and propagate the shapes through a
+ // couple of queues.
+ tensorflow::Scope root = tensorflow::Scope::NewRootScope();
+
+ auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT});
+ Output rnd =
+ ops::RandomNormal(root.WithOpName("rnd"), {3, 7}, DataType::DT_FLOAT);
+ Output square1 = ops::Square(root.WithOpName("Square1"), rnd);
+ auto enqueue1 = ops::QueueEnqueue(root.WithOpName("Enqueue1"), q1, {square1});
+ auto dequeue1 =
+ ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});
+
+ auto q2 =
+ ops::RandomShuffleQueue(root.WithOpName("Queue2"), {DataType::DT_FLOAT});
+ Output square2 = ops::Square(root.WithOpName("Square2"), dequeue1[0]);
+ auto enqueue2 = ops::QueueEnqueue(root.WithOpName("Enqueue2"), q2, {square2});
+ auto dequeue2 =
+ ops::QueueDequeue(root.WithOpName("Dequeue2"), q2, {DataType::DT_FLOAT});
+
+ auto q3 =
+ ops::RandomShuffleQueue(root.WithOpName("Queue3"), {DataType::DT_FLOAT});
+ auto dequeue3 =
+ ops::QueueDequeue(root.WithOpName("Dequeue3"), q3, {DataType::DT_FLOAT});
+
+ auto q4 =
+ ops::RandomShuffleQueue(root.WithOpName("Queue4"), {DataType::DT_FLOAT});
+ auto enqueue4 = ops::QueueEnqueue(root.WithOpName("Enqueue4"), q4, {square2});
+ auto enqueue4_2 =
+ ops::QueueEnqueue(root.WithOpName("Enqueue4_2"), q4, {dequeue3[0]});
+ auto dequeue4 =
+ ops::QueueDequeue(root.WithOpName("Dequeue4"), q4, {DataType::DT_FLOAT});
+
+ GrapplerItem item;
+ TF_CHECK_OK(root.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically());
+
+ const auto props1 = properties.GetOutputProperties("Dequeue1");
+ EXPECT_EQ(1, props1.size());
+ const OpInfo::TensorProperties& prop1 = props1[0];
+ EXPECT_EQ(DT_FLOAT, prop1.dtype());
+ EXPECT_FALSE(prop1.shape().unknown_rank());
+ EXPECT_EQ(2, prop1.shape().dim_size());
+ EXPECT_EQ(3, prop1.shape().dim(0).size());
+ EXPECT_EQ(7, prop1.shape().dim(1).size());
+
+ const auto props2 = properties.GetOutputProperties("Dequeue2");
+ EXPECT_EQ(1, props2.size());
+ const OpInfo::TensorProperties& prop2 = props2[0];
+ EXPECT_EQ(DT_FLOAT, prop2.dtype());
+ EXPECT_FALSE(prop2.shape().unknown_rank());
+ EXPECT_EQ(2, prop2.shape().dim_size());
+ EXPECT_EQ(3, prop2.shape().dim(0).size());
+ EXPECT_EQ(7, prop2.shape().dim(1).size());
+
+ // The dequeue3 op shape is unknown. The square2 op shape is known. Verify
+ // that we merge the 2 properly to determine the shape of the data coming out
+ // of the queue.
+ const auto props4 = properties.GetOutputProperties("Dequeue4");
+ EXPECT_EQ(1, props4.size());
+ const OpInfo::TensorProperties& prop4 = props4[0];
+ EXPECT_EQ(DT_FLOAT, prop4.dtype());
+ EXPECT_FALSE(prop4.shape().unknown_rank());
+ EXPECT_EQ(2, prop4.shape().dim_size());
+ EXPECT_EQ(3, prop4.shape().dim(0).size());
+ EXPECT_EQ(7, prop4.shape().dim(1).size());
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow