aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/shape_refiner_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/shape_refiner_test.cc')
-rw-r--r--tensorflow/core/common_runtime/shape_refiner_test.cc33
1 files changed, 33 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc
index d7e7c3b5ad..b8df6dd4f6 100644
--- a/tensorflow/core/common_runtime/shape_refiner_test.cc
+++ b/tensorflow/core/common_runtime/shape_refiner_test.cc
@@ -768,5 +768,38 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) {
m.AddNode(result).error_message());
}
+TEST(ShapeRefinerTest, IncrementalUpdates) {
+ Scope root = Scope::NewRootScope();
+ Graph* g = root.graph();
+ Node* queue;
+ TF_CHECK_OK(NodeBuilder("queue", "FIFOQueueV2")
+ .Attr("component_types", {DT_FLOAT})
+ .Finalize(g, &queue));
+ Node* dequeue;
+ TF_CHECK_OK(NodeBuilder("dequeue", "QueueDequeueV2")
+ .Attr("component_types", {DT_FLOAT})
+ .Input(queue)
+ .Finalize(g, &dequeue));
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
+ TF_ASSERT_OK(m.AddNode(queue));
+ TF_ASSERT_OK(m.AddNode(dequeue));
+
+ // At this point, the shapes of the dequeued tensor are unknown.
+ shape_inference::InferenceContext* ctx = m.GetContext(dequeue);
+ EXPECT_EQ("?", ctx->DebugString(ctx->output(0)));
+
+ // Inject a shape, and incrementally propagate it to the dequeue op.
+ ctx = m.GetContext(queue);
+ shape_inference::ShapeHandle shp = ctx->MakeShape({3, 7});
+ ctx->set_output_handle_shape(0, shp);
+ ctx->set_output_handle_dtype(0, DT_FLOAT);
+
+ bool refined = false;
+ TF_ASSERT_OK(m.UpdateNode(dequeue, &refined));
+ EXPECT_TRUE(refined);
+ ctx = m.GetContext(dequeue);
+ EXPECT_EQ("[3,7]", ctx->DebugString(ctx->output(0)));
+}
+
} // namespace
} // namespace tensorflow