aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/list_ops.cc
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2018-04-18 18:49:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-18 18:51:49 -0700
commita699d69c621fde118d4c89ba94658a9d7f91faac (patch)
tree4e8bdb669299977056ff21268e9edf7e68459e89 /tensorflow/core/ops/list_ops.cc
parentf1fb08bbb70047af0c86cc440ccc0581e64fd85f (diff)
[TF TensorLists] Add TensorListConcatLists
TensorListConcat concatenates two TensorLists' entries (supports non-scalar Tensors containing TensorLists). PiperOrigin-RevId: 193451787
Diffstat (limited to 'tensorflow/core/ops/list_ops.cc')
-rw-r--r--tensorflow/core/ops/list_ops.cc41
1 files changed, 41 insertions, 0 deletions
diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc
index 7af70110b7..b9f94ba1c5 100644
--- a/tensorflow/core/ops/list_ops.cc
+++ b/tensorflow/core/ops/list_ops.cc
@@ -295,5 +295,46 @@ REGISTER_OP("TensorListSetItem")
return Status::OK();
});
+REGISTER_OP("TensorListConcatLists")
+ .Input("input_a: variant")
+ .Input("input_b: variant")
+ .Attr("element_dtype: type")
+ .Output("output: variant")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ auto input_a = c->input(0);
+ auto input_b = c->input(1);
+ TF_RETURN_IF_ERROR(c->Merge(input_a, input_b, &input_a));
+ c->set_output(0, input_a);
+
+ DataType t;
+ TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t));
+
+ auto* handle_data_a = c->input_handle_shapes_and_types(0);
+ auto* handle_data_b = c->input_handle_shapes_and_types(1);
+ if (handle_data_a == nullptr && handle_data_b == nullptr) {
+ c->set_output_handle_shapes_and_types(0, {{c->UnknownShape(), t}});
+ return Status::OK();
+ }
+ shape_inference::ShapeAndType list_shape_type_a =
+ (handle_data_a) ? handle_data_a->at(0) : handle_data_b->at(0);
+ const shape_inference::ShapeAndType& list_shape_type_b =
+ (handle_data_b) ? handle_data_b->at(0) : handle_data_a->at(0);
+ if (list_shape_type_a.dtype != t) {
+ return errors::InvalidArgument("input_a.type != element_dtype: ",
+ DataTypeString(list_shape_type_a.dtype),
+ " vs. ", DataTypeString(t));
+ }
+ if (list_shape_type_b.dtype != t) {
+ return errors::InvalidArgument("input_b.type != element_dtype: ",
+ DataTypeString(list_shape_type_b.dtype),
+ " vs. ", DataTypeString(t));
+ }
+ TF_RETURN_IF_ERROR(c->Merge(list_shape_type_a.shape,
+ list_shape_type_b.shape,
+ &list_shape_type_a.shape));
+ c->set_output_handle_shapes_and_types(0, {list_shape_type_a});
+ return Status::OK();
+ });
+
} // namespace
} // namespace tensorflow