diff options
author | Benoit Steiner <bsteiner@google.com> | 2017-06-05 13:22:41 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-05 13:26:27 -0700 |
commit | 0df6760fe9e8c96e5d5396745a82e06f6a3737ec (patch) | |
tree | bf3bfca7c4a68c814cb43d3269f309e52903e903 | |
parent | 2ccfe8e764632cd05422bda12abe0f7a24abf000 (diff) |
Added a test to make sure that graph properties for variables are properly
reported
PiperOrigin-RevId: 158053084
-rw-r--r-- | tensorflow/core/grappler/costs/graph_properties_test.cc | 48 |
1 files changed, 48 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 6eca083184..bff5f7acc5 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -149,6 +149,54 @@ TEST_F(GraphPropertiesTest, DynamicProperties) { } } +TEST_F(GraphPropertiesTest, Variables) { + GrapplerItem item; + TF_CHECK_OK(NodeDefBuilder("Var", "Variable") + .Attr("dtype", DT_FLOAT) + .Attr("shape", TensorShape({3, 7})) + .Finalize(item.graph.add_node())); + item.fetch.push_back("Var"); + + Tensor initial_val(DT_FLOAT, TensorShape({3, 7})); + TF_CHECK_OK(NodeDefBuilder("InitialVal", "Const") + .Attr("dtype", DT_FLOAT) + .Attr("value", initial_val) + .Finalize(item.graph.add_node())); + TF_CHECK_OK(NodeDefBuilder("InitVar", "Assign") + .Input("Var", 0, DT_FLOAT_REF) + .Input("InitialVal", 0, DT_FLOAT) + .Finalize(item.graph.add_node())); + item.init_ops.push_back("InitVar"); + + { + GraphProperties static_properties(item); + TF_CHECK_OK(static_properties.InferStatically()); + + const auto props = static_properties.GetOutputProperties("Var"); + EXPECT_EQ(1, props.size()); + const OpInfo::TensorProperties& prop = props[0]; + EXPECT_EQ(DT_FLOAT_REF, prop.dtype()); + EXPECT_FALSE(prop.shape().unknown_rank()); + EXPECT_EQ(2, prop.shape().dim_size()); + EXPECT_EQ(3, prop.shape().dim(0).size()); + EXPECT_EQ(7, prop.shape().dim(1).size()); + } + { + TF_CHECK_OK(cluster_->Initialize(item)); + GraphProperties dynamic_properties(item); + TF_CHECK_OK(dynamic_properties.InferDynamically(cluster_.get())); + + const auto props = dynamic_properties.GetOutputProperties("Var"); + EXPECT_EQ(1, props.size()); + const OpInfo::TensorProperties& prop = props[0]; + EXPECT_EQ(DT_FLOAT_REF, prop.dtype()); + EXPECT_FALSE(prop.shape().unknown_rank()); + EXPECT_EQ(2, prop.shape().dim_size()); + EXPECT_EQ(3, prop.shape().dim(0).size()); + EXPECT_EQ(7, prop.shape().dim(1).size()); + } +} + TEST_F(GraphPropertiesTest, VarHandles) { GrapplerItem item; TF_CHECK_OK(NodeDefBuilder("Var", "VarHandleOp") |