aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/api_def/base_api/api_def_TensorListGather.pbtxt12
-rw-r--r--tensorflow/core/api_def/base_api/api_def_TensorListScatter.pbtxt14
-rw-r--r--tensorflow/core/kernels/list_kernels.cc12
-rw-r--r--tensorflow/core/kernels/list_kernels.cu.cc15
-rw-r--r--tensorflow/core/kernels/list_kernels.h121
-rw-r--r--tensorflow/core/ops/list_ops.cc51
6 files changed, 219 insertions, 6 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_TensorListGather.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorListGather.pbtxt
new file mode 100644
index 0000000000..3022fccb1e
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TensorListGather.pbtxt
@@ -0,0 +1,12 @@
+op {
+ graph_op_name: "TensorListGather"
+ summary: "Creates a Tensor by indexing into the TensorList."
+ description: <<END
+Each row in the produced Tensor corresponds to the element in the TensorList
+specified by the given index (see `tf.gather`).
+
+input_handle: The input tensor list.
+indices: The indices used to index into the list.
+values: The resulting tensor, with one row per entry of `indices`.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_TensorListScatter.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorListScatter.pbtxt
new file mode 100644
index 0000000000..35194b353e
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TensorListScatter.pbtxt
@@ -0,0 +1,14 @@
+op {
+ graph_op_name: "TensorListScatter"
+ summary: "Creates a TensorList by indexing into a Tensor."
+ description: <<END
+Each member of the TensorList corresponds to one row of the input tensor,
+specified by the given index (see `tf.gather`).
+
+tensor: The input tensor.
+indices: The indices used to select rows of the input tensor.
+element_shape: The shape of the elements in the list (can be less specified than
+ the shape of the tensor).
+output_handle: The TensorList.
+END
+}
diff --git a/tensorflow/core/kernels/list_kernels.cc b/tensorflow/core/kernels/list_kernels.cc
index 84fa63fc00..bca1cff41c 100644
--- a/tensorflow/core/kernels/list_kernels.cc
+++ b/tensorflow/core/kernels/list_kernels.cc
@@ -588,7 +588,11 @@ REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU(bfloat16);
REGISTER_KERNEL_BUILDER(Name("TensorListStack") \
.TypeConstraint<T>("element_dtype") \
.Device(DEVICE_CPU), \
- TensorListStack<CPUDevice, T>)
+ TensorListStack<CPUDevice, T>) \
+ REGISTER_KERNEL_BUILDER(Name("TensorListGather") \
+ .TypeConstraint<T>("element_dtype") \
+ .Device(DEVICE_CPU), \
+ TensorListGather<CPUDevice, T>)
TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_STACK_CPU);
REGISTER_TENSOR_LIST_STACK_CPU(quint8);
@@ -604,7 +608,11 @@ REGISTER_TENSOR_LIST_STACK_CPU(bfloat16);
REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor") \
.TypeConstraint<T>("element_dtype") \
.Device(DEVICE_CPU), \
- TensorListFromTensor<CPUDevice, T>)
+ TensorListFromTensor<CPUDevice, T>) \
+ REGISTER_KERNEL_BUILDER(Name("TensorListScatter") \
+ .TypeConstraint<T>("element_dtype") \
+ .Device(DEVICE_CPU), \
+ TensorListScatter<CPUDevice, T>)
TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_FROM_TENSOR_CPU);
REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(quint8);
diff --git a/tensorflow/core/kernels/list_kernels.cu.cc b/tensorflow/core/kernels/list_kernels.cu.cc
index 0ea9362cbe..c591226b76 100644
--- a/tensorflow/core/kernels/list_kernels.cu.cc
+++ b/tensorflow/core/kernels/list_kernels.cu.cc
@@ -40,7 +40,12 @@ typedef Eigen::GpuDevice GPUDevice;
REGISTER_KERNEL_BUILDER(Name("TensorListStack") \
.TypeConstraint<T>("element_dtype") \
.Device(DEVICE_GPU), \
- TensorListStack<GPUDevice, T>)
+ TensorListStack<GPUDevice, T>) \
+ REGISTER_KERNEL_BUILDER(Name("TensorListGather") \
+ .TypeConstraint<T>("element_dtype") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("indices"), \
+ TensorListGather<GPUDevice, T>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_STACK_GPU);
REGISTER_TENSOR_LIST_STACK_GPU(bfloat16);
@@ -71,7 +76,13 @@ REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_GPU(bool);
.TypeConstraint<T>("element_dtype") \
.Device(DEVICE_GPU) \
.HostMemory("element_shape"), \
- TensorListFromTensor<GPUDevice, T>)
+ TensorListFromTensor<GPUDevice, T>) \
+ REGISTER_KERNEL_BUILDER(Name("TensorListScatter") \
+ .TypeConstraint<T>("element_dtype") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("element_shape") \
+ .HostMemory("indices"), \
+ TensorListScatter<GPUDevice, T>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_FROM_TENSOR_GPU);
REGISTER_TENSOR_LIST_FROM_TENSOR_GPU(bfloat16);
diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h
index b3f74c060b..066a1d603b 100644
--- a/tensorflow/core/kernels/list_kernels.h
+++ b/tensorflow/core/kernels/list_kernels.h
@@ -134,6 +134,74 @@ class TensorListStack : public OpKernel {
};
+// Kernel for the "TensorListGather" op: assembles a single dense tensor from
+// the list elements selected by an index tensor.
+//
+// Inputs:
+//   0: scalar Variant that must hold a TensorList.
+//   1: int32 tensor of list positions; it is read through its flattened view.
+// Output:
+//   0: tensor of shape [indices.NumElements(), <list element shape>...].
+//
+// Compute() reports InvalidArgument when the handle is not a TensorList, when
+// the list dtype differs from the "element_dtype" attr, when the list element
+// shape is not fully defined, when an index falls outside
+// [0, number of list elements), or when a selected element is incompatible
+// with the list element shape.
template <typename Device, typename T>
+class TensorListGather : public OpKernel {
+ public:
+  // Flattened (1 x N) read-only views of the selected elements, in the form
+  // passed to ConcatCPU / ConcatGPU below.
+  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
+      ConstMatrixVector;
+
+  explicit TensorListGather(OpKernelConstruction* c) : OpKernel(c) {
+    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
+  }
+
+  void Compute(OpKernelContext* c) override {
+    const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>();
+    OP_REQUIRES(c, l != nullptr,
+                errors::InvalidArgument(
+                    "Input handle is not a list. Saw: '",
+                    c->input(0).scalar<Variant>()().DebugString(), "'"));
+    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
+                errors::InvalidArgument("Invalid data types; op elements ",
+                                        DataTypeString(element_dtype_),
+                                        " but list elements ",
+                                        DataTypeString(l->element_dtype)));
+    // The output shape is derived from the list element shape, so every
+    // dimension of it has to be known.
+    OP_REQUIRES(c, l->element_shape.IsFullyDefined(),
+                errors::InvalidArgument("Tried to gather elements from a list "
+                                        "with non-fully-defined shape: ",
+                                        l->element_shape.DebugString()));
+    const Tensor& indices = c->input(1);
+    // Output shape: one leading dimension counting the indices, followed by
+    // the dimensions of a single list element.
+    TensorShape resulting_shape;
+    resulting_shape.AddDim(indices.NumElements());
+    for (TensorShapeDim s : l->element_shape) {
+      resulting_shape.AddDim(s.size);
+    }
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(c, c->allocate_output(0, resulting_shape, &output));
+    if (output->NumElements() == 0) {
+      // Nothing to copy (no indices, or zero-sized elements).
+      return;
+    }
+
+    const auto indices_flat = indices.flat<int32>();
+    ConstMatrixVector inputs_flat;
+    // One view is stored per gathered index, not per list element.
+    inputs_flat.reserve(indices.NumElements());
+    for (int index = 0; index < indices.NumElements(); ++index) {
+      const int i = indices_flat(index);
+      // Reject negative positions explicitly instead of relying on the
+      // implicit signed-to-unsigned conversion in the comparison.
+      OP_REQUIRES(
+          c, i >= 0 && static_cast<size_t>(i) < l->tensors.size(),
+          errors::InvalidArgument("Index ", i, " out of range; list only has ",
+                                  l->tensors.size(), " elements."));
+      const Tensor& t = l->tensors[i];
+      OP_REQUIRES(c, l->element_shape.IsCompatibleWith(t.shape()),
+                  errors::InvalidArgument(
+                      "Tensor with invalid shape in list. List element shape: ",
+                      l->element_shape.DebugString(),
+                      " and tensor shape: ", t.shape().DebugString()));
+      inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
+          t.shaped<T, 2>({1, t.NumElements()})));
+    }
+    // The output is viewed as a single row as well, matching the 1 x N views
+    // collected above.
+    auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
+
+#if GOOGLE_CUDA
+    if (std::is_same<Device, Eigen::GpuDevice>::value) {
+      ConcatGPU<T>(c, inputs_flat, output, &output_flat);
+      return;
+    }
+#endif  // GOOGLE_CUDA
+    ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
+  }
+
+ private:
+  // Value of the "element_dtype" attr, read at construction time.
+  DataType element_dtype_;
+};
+
+template <typename Device, typename T>
class TensorListFromTensor : public OpKernel {
public:
TensorListFromTensor(OpKernelConstruction* c) : OpKernel(c) {}
@@ -178,6 +246,59 @@ class TensorListFromTensor : public OpKernel {
}
};
+// Kernel for the "TensorListScatter" op: builds a TensorList whose k-th
+// element is a copy of row indices[k] of the input tensor.
+//
+// Inputs:
+//   0: the tensor to take rows from; must be at least a vector.
+//   1: int32 tensor of row numbers; it is read through its flattened view.
+//   2: shape tensor giving the (possibly partial) element shape of the list.
+// Output:
+//   0: scalar Variant, allocated on the host, holding the new TensorList.
+//
+// Compute() reports InvalidArgument when the input is a scalar, when the
+// requested element shape is incompatible with the shape of one row, or when
+// an index falls outside [0, number of rows).
+template <typename Device, typename T>
+class TensorListScatter : public OpKernel {
+ public:
+  explicit TensorListScatter(OpKernelConstruction* c) : OpKernel(c) {}
+
+  void Compute(OpKernelContext* c) override {
+    Tensor* output_tensor;
+    AllocatorAttributes attr;
+    attr.set_on_host(true);
+    OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
+    const Tensor& indices = c->input(1);
+    PartialTensorShape element_shape;
+    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(2), &element_shape));
+    TensorList output_list;
+    const Tensor& t = c->input(0);
+    output_list.element_dtype = t.dtype();
+    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
+                errors::InvalidArgument(
+                    "Tensor must be at least a vector, but saw shape: ",
+                    t.shape().DebugString()));
+    // Shape of a single row: the input shape without its leading dimension.
+    TensorShape output_shape(t.shape());
+    output_shape.RemoveDim(0);
+    OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
+                errors::InvalidArgument(
+                    "Specified a list with shape ", element_shape.DebugString(),
+                    " from a tensor with shape ", output_shape.DebugString()));
+    output_list.element_shape = element_shape;
+    output_list.tensors.reserve(indices.NumElements());
+
+    const auto num_rows = t.shape().dim_size(0);
+    const auto indices_flat = indices.flat<int32>();
+    for (int index = 0; index < indices.NumElements(); ++index) {
+      const int i = indices_flat(index);
+      // Both bounds are checked: the comparison against the row count is
+      // signed, so a negative index would otherwise reach Slice() below.
+      // NOTE(review): Slice() is assumed not to tolerate a negative start.
+      OP_REQUIRES(c, i >= 0 && i < num_rows,
+                  errors::InvalidArgument("Trying to scatter index ", i,
+                                          " from tensor with ", num_rows,
+                                          " rows."));
+      // Take row i as a [1, ...] slice, then drop the leading dimension so
+      // the element has the shape of one row.
+      Tensor tmp = t.Slice(i, i + 1);
+      TensorShape tmp_shape = tmp.shape();
+      tmp_shape.RemoveDim(0);
+      OP_REQUIRES(c, tmp.CopyFrom(tmp, tmp_shape),
+                  errors::Unknown("Unexpected shape error."));
+      // TODO(apassos) maybe not always align; but weird compiler bugs seem to
+      // prevent this.
+      Tensor aligned;
+      OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
+      // TODO(apassos) do all slices in a single kernel invocation instead of
+      // many small ones.
+      aligned.flat<T>().device(c->eigen_device<Device>()) =
+          tmp.unaligned_flat<T>();
+      output_list.tensors.push_back(std::move(aligned));
+    }
+    output_tensor->scalar<Variant>()() = std::move(output_list);
+  }
+};
+
template <typename Device>
Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a,
const TensorList& b, TensorList* out) {
diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc
index b9f94ba1c5..7d79df9c1c 100644
--- a/tensorflow/core/ops/list_ops.cc
+++ b/tensorflow/core/ops/list_ops.cc
@@ -210,7 +210,8 @@ REGISTER_OP("TensorListFromTensor")
shape_inference::ShapeHandle o;
TF_RETURN_IF_ERROR(c->Subshape(s, 1, &o));
shape_inference::ShapeHandle element_shape;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &element_shape));
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
+ 1, &element_shape));
TF_RETURN_IF_ERROR(c->Merge(o, element_shape, &o));
c->set_output_handle_shapes_and_types(
0, std::vector<shape_inference::ShapeAndType>{{element_shape, t}});
@@ -240,7 +241,8 @@ REGISTER_OP("TensorListReserve")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->Scalar());
shape_inference::ShapeHandle s;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(
+ c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(0, &s));
DataType t;
TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t));
c->set_output_handle_shapes_and_types(
@@ -295,6 +297,51 @@ REGISTER_OP("TensorListSetItem")
return Status::OK();
});
+// Op definition for TensorListGather. The shape function prepends the shape
+// of `indices` to the list element shape recorded in the input handle data
+// (or to an unknown shape when no handle data is available), and rejects a
+// list whose recorded dtype differs from the "element_dtype" attr.
+REGISTER_OP("TensorListGather")
+    .Input("input_handle: variant")
+    .Input("indices: int32")
+    .Output("values: element_dtype")
+    .Attr("element_dtype: type")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      DataType expected_dtype;
+      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &expected_dtype));
+      // Start from "unknown" and refine only if the handle carries metadata.
+      shape_inference::ShapeHandle per_element_shape = c->UnknownShape();
+      if (auto* list_info = c->input_handle_shapes_and_types(0)) {
+        const shape_inference::ShapeAndType& first_entry = (*list_info)[0];
+        if (first_entry.dtype != expected_dtype) {
+          return errors::InvalidArgument(
+              "Expected list with element dtype ",
+              DataTypeString(expected_dtype),
+              " but got list with element dtype ",
+              DataTypeString(first_entry.dtype));
+        }
+        per_element_shape = first_entry.shape;
+      }
+      // Output shape = shape(indices) followed by the element shape.
+      shape_inference::ShapeHandle gathered_shape;
+      TF_RETURN_IF_ERROR(
+          c->Concatenate(c->input(1), per_element_shape, &gathered_shape));
+      c->set_output(0, gathered_shape);
+      return Status::OK();
+    });
+
+// Op definition for TensorListScatter. The output is a scalar handle; the
+// shape function attaches to it the element shape parsed from the
+// `element_shape` input (input 2) together with the "element_dtype" attr.
+REGISTER_OP("TensorListScatter")
+    .Input("tensor: element_dtype")
+    .Input("indices: int32")
+    .Input("element_shape: shape_type")
+    .Output("output_handle: variant")
+    .Attr("element_dtype: type")
+    .Attr("shape_type: {int32, int64}")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      DataType list_dtype;
+      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &list_dtype));
+      shape_inference::ShapeHandle declared_element_shape;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
+          2, &declared_element_shape));
+      // The list handle itself is always a scalar.
+      c->set_output(0, c->Scalar());
+      // Record what the list holds so downstream shape functions can use it.
+      c->set_output_handle_shapes_and_types(
+          0, std::vector<shape_inference::ShapeAndType>{
+                 {declared_element_shape, list_dtype}});
+      return Status::OK();
+    });
+
REGISTER_OP("TensorListConcatLists")
.Input("input_a: variant")
.Input("input_b: variant")