diff options
Diffstat (limited to 'tensorflow/core/common_runtime/shape_refiner_test.cc')
-rw-r--r-- | tensorflow/core/common_runtime/shape_refiner_test.cc | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc index d7e7c3b5ad..b8df6dd4f6 100644 --- a/tensorflow/core/common_runtime/shape_refiner_test.cc +++ b/tensorflow/core/common_runtime/shape_refiner_test.cc @@ -768,5 +768,38 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) { m.AddNode(result).error_message()); } +TEST(ShapeRefinerTest, IncrementalUpdates) { + Scope root = Scope::NewRootScope(); + Graph* g = root.graph(); + Node* queue; + TF_CHECK_OK(NodeBuilder("queue", "FIFOQueueV2") + .Attr("component_types", {DT_FLOAT}) + .Finalize(g, &queue)); + Node* dequeue; + TF_CHECK_OK(NodeBuilder("dequeue", "QueueDequeueV2") + .Attr("component_types", {DT_FLOAT}) + .Input(queue) + .Finalize(g, &dequeue)); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); + TF_ASSERT_OK(m.AddNode(queue)); + TF_ASSERT_OK(m.AddNode(dequeue)); + + // At this point, the shapes of the dequeued tensor are unknown. + shape_inference::InferenceContext* ctx = m.GetContext(dequeue); + EXPECT_EQ("?", ctx->DebugString(ctx->output(0))); + + // Inject a shape, and incrementally propagate it to the dequeue op. + ctx = m.GetContext(queue); + shape_inference::ShapeHandle shp = ctx->MakeShape({3, 7}); + ctx->set_output_handle_shape(0, shp); + ctx->set_output_handle_dtype(0, DT_FLOAT); + + bool refined = false; + TF_ASSERT_OK(m.UpdateNode(dequeue, &refined)); + EXPECT_TRUE(refined); + ctx = m.GetContext(dequeue); + EXPECT_EQ("[3,7]", ctx->DebugString(ctx->output(0))); +} + } // namespace } // namespace tensorflow |