aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/list_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/list_ops.cc')
-rw-r--r--tensorflow/core/ops/list_ops.cc51
1 files changed, 49 insertions, 2 deletions
diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc
index b9f94ba1c5..7d79df9c1c 100644
--- a/tensorflow/core/ops/list_ops.cc
+++ b/tensorflow/core/ops/list_ops.cc
@@ -210,7 +210,8 @@ REGISTER_OP("TensorListFromTensor")
shape_inference::ShapeHandle o;
TF_RETURN_IF_ERROR(c->Subshape(s, 1, &o));
shape_inference::ShapeHandle element_shape;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &element_shape));
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
+ 1, &element_shape));
TF_RETURN_IF_ERROR(c->Merge(o, element_shape, &o));
c->set_output_handle_shapes_and_types(
0, std::vector<shape_inference::ShapeAndType>{{element_shape, t}});
@@ -240,7 +241,8 @@ REGISTER_OP("TensorListReserve")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->Scalar());
shape_inference::ShapeHandle s;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(
+ c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(0, &s));
DataType t;
TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t));
c->set_output_handle_shapes_and_types(
@@ -295,6 +297,51 @@ REGISTER_OP("TensorListSetItem")
return Status::OK();
});
+REGISTER_OP("TensorListGather")
+ .Input("input_handle: variant")
+ .Input("indices: int32")
+ .Output("values: element_dtype")
+ .Attr("element_dtype: type")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ DataType t;
+ TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t));
+ auto* handle_data = c->input_handle_shapes_and_types(0);
+ shape_inference::ShapeHandle element_shape = c->UnknownShape();
+ if (handle_data != nullptr) {
+ const shape_inference::ShapeAndType& list_shape_type =
+ (*handle_data)[0];
+ element_shape = list_shape_type.shape;
+ if (list_shape_type.dtype != t) {
+ return errors::InvalidArgument("Expected list with element dtype ",
+ DataTypeString(t),
+ " but got list with element dtype ",
+ DataTypeString(list_shape_type.dtype));
+ }
+ }
+ shape_inference::ShapeHandle out;
+ TF_RETURN_IF_ERROR(c->Concatenate(c->input(1), element_shape, &out));
+ c->set_output(0, out);
+ return Status::OK();
+ });
+
+REGISTER_OP("TensorListScatter")
+ .Input("tensor: element_dtype")
+ .Input("indices: int32")
+ .Input("element_shape: shape_type")
+ .Output("output_handle: variant")
+ .Attr("element_dtype: type")
+ .Attr("shape_type: {int32, int64}")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ DataType t;
+ TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t));
+ shape_inference::ShapeHandle s;
+ TF_RETURN_IF_ERROR(
+ c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(2, &s));
+ c->set_output_handle_shapes_and_types(0, {{s, t}});
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ });
+
REGISTER_OP("TensorListConcatLists")
.Input("input_a: variant")
.Input("input_b: variant")