diff options
Diffstat (limited to 'tensorflow/core/ops/list_ops.cc')
-rw-r--r-- | tensorflow/core/ops/list_ops.cc | 51 |
1 files changed, 49 insertions, 2 deletions
diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc index b9f94ba1c5..7d79df9c1c 100644 --- a/tensorflow/core/ops/list_ops.cc +++ b/tensorflow/core/ops/list_ops.cc @@ -210,7 +210,8 @@ REGISTER_OP("TensorListFromTensor") shape_inference::ShapeHandle o; TF_RETURN_IF_ERROR(c->Subshape(s, 1, &o)); shape_inference::ShapeHandle element_shape; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &element_shape)); + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape( + 1, &element_shape)); TF_RETURN_IF_ERROR(c->Merge(o, element_shape, &o)); c->set_output_handle_shapes_and_types( 0, std::vector<shape_inference::ShapeAndType>{{element_shape, t}}); @@ -240,7 +241,8 @@ REGISTER_OP("TensorListReserve") .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->Scalar()); shape_inference::ShapeHandle s; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); + TF_RETURN_IF_ERROR( + c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(0, &s)); DataType t; TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t)); c->set_output_handle_shapes_and_types( @@ -295,6 +297,51 @@ REGISTER_OP("TensorListSetItem") return Status::OK(); }); +REGISTER_OP("TensorListGather") + .Input("input_handle: variant") + .Input("indices: int32") + .Output("values: element_dtype") + .Attr("element_dtype: type") + .SetShapeFn([](shape_inference::InferenceContext* c) { + DataType t; + TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t)); + auto* handle_data = c->input_handle_shapes_and_types(0); + shape_inference::ShapeHandle element_shape = c->UnknownShape(); + if (handle_data != nullptr) { + const shape_inference::ShapeAndType& list_shape_type = + (*handle_data)[0]; + element_shape = list_shape_type.shape; + if (list_shape_type.dtype != t) { + return errors::InvalidArgument("Expected list with element dtype ", + DataTypeString(t), + " but got list with element dtype ", + DataTypeString(list_shape_type.dtype)); + } + } + shape_inference::ShapeHandle out; + TF_RETURN_IF_ERROR(c->Concatenate(c->input(1), element_shape, &out)); + c->set_output(0, out); + return Status::OK(); + }); + +REGISTER_OP("TensorListScatter") + .Input("tensor: element_dtype") + .Input("indices: int32") + .Input("element_shape: shape_type") + .Output("output_handle: variant") + .Attr("element_dtype: type") + .Attr("shape_type: {int32, int64}") + .SetShapeFn([](shape_inference::InferenceContext* c) { + DataType t; + TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t)); + shape_inference::ShapeHandle s; + TF_RETURN_IF_ERROR( + c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(2, &s)); + c->set_output_handle_shapes_and_types(0, {{s, t}}); + c->set_output(0, c->Scalar()); + return Status::OK(); + }); + REGISTER_OP("TensorListConcatLists") .Input("input_a: variant") .Input("input_b: variant") |