diff options
author | 2018-04-18 18:49:02 -0700 | |
---|---|---|
committer | 2018-04-18 18:51:49 -0700 | |
commit | a699d69c621fde118d4c89ba94658a9d7f91faac (patch) | |
tree | 4e8bdb669299977056ff21268e9edf7e68459e89 /tensorflow/core/ops/list_ops.cc | |
parent | f1fb08bbb70047af0c86cc440ccc0581e64fd85f (diff) |
[TF TensorLists] Add TensorListConcatLists
TensorListConcat concatenates two TensorLists' entries (supports non-scalar
Tensors containing TensorLists).
PiperOrigin-RevId: 193451787
Diffstat (limited to 'tensorflow/core/ops/list_ops.cc')
-rw-r--r-- | tensorflow/core/ops/list_ops.cc | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc index 7af70110b7..b9f94ba1c5 100644 --- a/tensorflow/core/ops/list_ops.cc +++ b/tensorflow/core/ops/list_ops.cc @@ -295,5 +295,46 @@ REGISTER_OP("TensorListSetItem") return Status::OK(); }); +REGISTER_OP("TensorListConcatLists") + .Input("input_a: variant") + .Input("input_b: variant") + .Attr("element_dtype: type") + .Output("output: variant") + .SetShapeFn([](shape_inference::InferenceContext* c) { + auto input_a = c->input(0); + auto input_b = c->input(1); + TF_RETURN_IF_ERROR(c->Merge(input_a, input_b, &input_a)); + c->set_output(0, input_a); + + DataType t; + TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t)); + + auto* handle_data_a = c->input_handle_shapes_and_types(0); + auto* handle_data_b = c->input_handle_shapes_and_types(1); + if (handle_data_a == nullptr && handle_data_b == nullptr) { + c->set_output_handle_shapes_and_types(0, {{c->UnknownShape(), t}}); + return Status::OK(); + } + shape_inference::ShapeAndType list_shape_type_a = + (handle_data_a) ? handle_data_a->at(0) : handle_data_b->at(0); + const shape_inference::ShapeAndType& list_shape_type_b = + (handle_data_b) ? handle_data_b->at(0) : handle_data_a->at(0); + if (list_shape_type_a.dtype != t) { + return errors::InvalidArgument("input_a.type != element_dtype: ", + DataTypeString(list_shape_type_a.dtype), + " vs. ", DataTypeString(t)); + } + if (list_shape_type_b.dtype != t) { + return errors::InvalidArgument("input_b.type != element_dtype: ", + DataTypeString(list_shape_type_b.dtype), + " vs. ", DataTypeString(t)); + } + TF_RETURN_IF_ERROR(c->Merge(list_shape_type_a.shape, + list_shape_type_b.shape, + &list_shape_type_a.shape)); + c->set_output_handle_shapes_and_types(0, {list_shape_type_a}); + return Status::OK(); + }); + } // namespace } // namespace tensorflow |