aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/data_flow_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/data_flow_ops.cc')
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc12
1 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index f82e9d1eb7..f35a1bb648 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -623,7 +623,17 @@ REGISTER_OP("QueueDequeueV2")
.Output("components: component_types")
.Attr("component_types: list(type) >= 1")
.Attr("timeout_ms: int = -1")
- .SetShapeFn(shape_inference::UnknownShape)
+ .SetShapeFn([](InferenceContext* c) {
+ if (c->num_outputs() == 1) {
+ c->set_output(0, c->input_handle_shape(0));
+ } else {
+ // TODO(vrv): handle the case of multiple outputs.
+ for (int i = 0; i < c->num_outputs(); ++i) {
+ c->set_output(i, c->UnknownShape());
+ }
+ }
+ return Status::OK();
+ })
.Doc(R"doc(
Dequeues a tuple of one or more tensors from the given queue.