diff options
Diffstat (limited to 'tensorflow/core/ops/data_flow_ops.cc')
-rw-r--r-- | tensorflow/core/ops/data_flow_ops.cc | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index f82e9d1eb7..f35a1bb648 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -623,7 +623,17 @@ REGISTER_OP("QueueDequeueV2") .Output("components: component_types") .Attr("component_types: list(type) >= 1") .Attr("timeout_ms: int = -1") - .SetShapeFn(shape_inference::UnknownShape) + .SetShapeFn([](InferenceContext* c) { + if (c->num_outputs() == 1) { + c->set_output(0, c->input_handle_shape(0)); + } else { + // TODO(vrv): handle the case of multiple outputs. + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->UnknownShape()); + } + } + return Status::OK(); + }) .Doc(R"doc( Dequeues a tuple of one or more tensors from the given queue. |