aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2017-06-05 13:22:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-05 13:26:27 -0700
commit0df6760fe9e8c96e5d5396745a82e06f6a3737ec (patch)
treebf3bfca7c4a68c814cb43d3269f309e52903e903
parent2ccfe8e764632cd05422bda12abe0f7a24abf000 (diff)
Added a test to make sure that graph properties for variables are properly
reported PiperOrigin-RevId: 158053084
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_test.cc48
1 files changed, 48 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 6eca083184..bff5f7acc5 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -149,6 +149,54 @@ TEST_F(GraphPropertiesTest, DynamicProperties) {
}
}
+TEST_F(GraphPropertiesTest, Variables) {
+ GrapplerItem item;
+ TF_CHECK_OK(NodeDefBuilder("Var", "Variable")
+ .Attr("dtype", DT_FLOAT)
+ .Attr("shape", TensorShape({3, 7}))
+ .Finalize(item.graph.add_node()));
+ item.fetch.push_back("Var");
+
+ Tensor initial_val(DT_FLOAT, TensorShape({3, 7}));
+ TF_CHECK_OK(NodeDefBuilder("InitialVal", "Const")
+ .Attr("dtype", DT_FLOAT)
+ .Attr("value", initial_val)
+ .Finalize(item.graph.add_node()));
+ TF_CHECK_OK(NodeDefBuilder("InitVar", "Assign")
+ .Input("Var", 0, DT_FLOAT_REF)
+ .Input("InitialVal", 0, DT_FLOAT)
+ .Finalize(item.graph.add_node()));
+ item.init_ops.push_back("InitVar");
+
+ {
+ GraphProperties static_properties(item);
+ TF_CHECK_OK(static_properties.InferStatically());
+
+ const auto props = static_properties.GetOutputProperties("Var");
+ EXPECT_EQ(1, props.size());
+ const OpInfo::TensorProperties& prop = props[0];
+ EXPECT_EQ(DT_FLOAT_REF, prop.dtype());
+ EXPECT_FALSE(prop.shape().unknown_rank());
+ EXPECT_EQ(2, prop.shape().dim_size());
+ EXPECT_EQ(3, prop.shape().dim(0).size());
+ EXPECT_EQ(7, prop.shape().dim(1).size());
+ }
+ {
+ TF_CHECK_OK(cluster_->Initialize(item));
+ GraphProperties dynamic_properties(item);
+ TF_CHECK_OK(dynamic_properties.InferDynamically(cluster_.get()));
+
+ const auto props = dynamic_properties.GetOutputProperties("Var");
+ EXPECT_EQ(1, props.size());
+ const OpInfo::TensorProperties& prop = props[0];
+ EXPECT_EQ(DT_FLOAT_REF, prop.dtype());
+ EXPECT_FALSE(prop.shape().unknown_rank());
+ EXPECT_EQ(2, prop.shape().dim_size());
+ EXPECT_EQ(3, prop.shape().dim(0).size());
+ EXPECT_EQ(7, prop.shape().dim(1).size());
+ }
+}
+
TEST_F(GraphPropertiesTest, VarHandles) {
GrapplerItem item;
TF_CHECK_OK(NodeDefBuilder("Var", "VarHandleOp")