aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/list_kernels.h
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-08-30 10:14:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-30 10:23:30 -0700
commitee89fccfd1db25563dadd0e3b4336612d7c52e0a (patch)
tree83e0034a75e66a60a97d675238e93ce2278d3c7a /tensorflow/core/kernels/list_kernels.h
parented248787e7045cc484fd7cff3d2447c5c776aa84 (diff)
TensorListScatter and TensorListGather
This closes one API hole between TensorList and TensorArray. PiperOrigin-RevId: 210932049
Diffstat (limited to 'tensorflow/core/kernels/list_kernels.h')
-rw-r--r--tensorflow/core/kernels/list_kernels.h121
1 files changed, 121 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h
index b3f74c060b..066a1d603b 100644
--- a/tensorflow/core/kernels/list_kernels.h
+++ b/tensorflow/core/kernels/list_kernels.h
@@ -134,6 +134,74 @@ class TensorListStack : public OpKernel {
};
+// Kernel that gathers the list elements selected by an index tensor and
+// concatenates them into a single output tensor.
+//
+// Inputs:
+//   0: scalar Variant expected to hold a TensorList.
+//   1: indices tensor, read as int32; each value selects one list element.
+// Output:
+//   0: tensor of shape [indices.NumElements()] + list element_shape, filled by
+//      concatenating the selected elements in index order.
+//
+// Reports InvalidArgument when input 0 is not a list, when the list dtype
+// differs from the `element_dtype` attr, when the list element shape is not
+// fully defined, when an index is out of range, or when a selected element
+// has a shape incompatible with the list element shape.
template <typename Device, typename T>
+class TensorListGather : public OpKernel {
+ public:
+  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
+      ConstMatrixVector;
+
+  explicit TensorListGather(OpKernelConstruction* c) : OpKernel(c) {
+    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
+  }
+
+  void Compute(OpKernelContext* c) override {
+    const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>();
+    OP_REQUIRES(c, l != nullptr,
+                errors::InvalidArgument(
+                    "Input handle is not a list. Saw: '",
+                    c->input(0).scalar<Variant>()().DebugString(), "'"));
+    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
+                errors::InvalidArgument("Invalid data types; op elements ",
+                                        DataTypeString(element_dtype_),
+                                        " but list elements ",
+                                        DataTypeString(l->element_dtype)));
+    // The output shape is derived from the list's element shape, so every
+    // dimension of it has to be known.
+    OP_REQUIRES(c, l->element_shape.IsFullyDefined(),
+                errors::InvalidArgument("Tried to gather elements from a list "
+                                        "with non-fully-defined shape: ",
+                                        l->element_shape.DebugString()));
+    const Tensor& indices = c->input(1);
+
+    // Output shape: one leading dimension with an entry per index, followed by
+    // the dimensions of a single list element.
+    TensorShape resulting_shape;
+    resulting_shape.AddDim(indices.NumElements());
+    for (TensorShapeDim s : l->element_shape) {
+      resulting_shape.AddDim(s.size);
+    }
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(c, c->allocate_output(0, resulting_shape, &output));
+    if (output->NumElements() == 0) {
+      // Nothing to copy.
+      return;
+    }
+
+    // One flattened [1, NumElements] view per gathered element. Exactly one
+    // entry is appended per index, so reserve by the index count rather than
+    // by the list length.
+    ConstMatrixVector inputs_flat;
+    inputs_flat.reserve(indices.NumElements());
+    const auto indices_flat = indices.flat<int32>();
+    for (int index = 0; index < indices.NumElements(); ++index) {
+      const int i = indices_flat(index);
+      // Check the lower bound explicitly instead of relying on the implicit
+      // signed-to-unsigned conversion in the comparison against size().
+      OP_REQUIRES(
+          c, i >= 0 && static_cast<size_t>(i) < l->tensors.size(),
+          errors::InvalidArgument("Index ", i, " out of range; list only has ",
+                                  l->tensors.size(), " elements."));
+      const Tensor& t = l->tensors[i];
+      OP_REQUIRES(c, l->element_shape.IsCompatibleWith(t.shape()),
+                  errors::InvalidArgument(
+                      "Tensor with invalid shape in list. List element shape: ",
+                      l->element_shape.DebugString(),
+                      " and tensor shape: ", t.shape().DebugString()));
+      inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
+          t.shaped<T, 2>({1, t.NumElements()})));
+    }
+    auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
+
+#if GOOGLE_CUDA
+    if (std::is_same<Device, Eigen::GpuDevice>::value) {
+      ConcatGPU<T>(c, inputs_flat, output, &output_flat);
+      return;
+    }
+#endif  // GOOGLE_CUDA
+    ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
+  }
+
+ private:
+  // Value of the "element_dtype" attr; must match the dtype stored in the list.
+  DataType element_dtype_;
+};
+
+template <typename Device, typename T>
class TensorListFromTensor : public OpKernel {
public:
TensorListFromTensor(OpKernelConstruction* c) : OpKernel(c) {}
@@ -178,6 +246,59 @@ class TensorListFromTensor : public OpKernel {
}
};
+// Kernel that builds a TensorList out of selected rows of a tensor.
+//
+// Inputs:
+//   0: tensor of rank >= 1 whose dimension 0 indexes the rows.
+//   1: indices tensor, read as int32; each value selects a row of input 0.
+//   2: tensor describing the list element shape (parsed with
+//      TensorShapeFromTensor); must be compatible with the shape of one row.
+// Output:
+//   0: scalar Variant, allocated in host memory, holding the resulting
+//      TensorList. The list has one element per index, in index order, each a
+//      freshly allocated copy of the selected row with dimension 0 removed.
+//
+// Reports InvalidArgument when input 0 is a scalar, when the requested element
+// shape is incompatible with the row shape, or when an index is out of range.
+template <typename Device, typename T>
+class TensorListScatter : public OpKernel {
+ public:
+  explicit TensorListScatter(OpKernelConstruction* c) : OpKernel(c) {}
+
+  void Compute(OpKernelContext* c) override {
+    Tensor* output_tensor = nullptr;
+    AllocatorAttributes attr;
+    attr.set_on_host(true);
+    OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
+    const Tensor& indices = c->input(1);
+    PartialTensorShape element_shape;
+    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(2), &element_shape));
+    TensorList output_list;
+    const Tensor& t = c->input(0);
+    output_list.element_dtype = t.dtype();
+    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
+                errors::InvalidArgument(
+                    "Tensor must be at least a vector, but saw shape: ",
+                    t.shape().DebugString()));
+    // Shape of a single row: the input shape without its leading dimension.
+    TensorShape output_shape(t.shape());
+    output_shape.RemoveDim(0);
+    OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
+                errors::InvalidArgument(
+                    "Specified a list with shape ", element_shape.DebugString(),
+                    " from a tensor with shape ", output_shape.DebugString()));
+    output_list.element_shape = element_shape;
+    output_list.tensors.reserve(indices.NumElements());
+    const auto indices_flat = indices.flat<int32>();
+    const auto num_rows = t.shape().dim_size(0);
+    for (int index = 0; index < indices.NumElements(); ++index) {
+      const int i = indices_flat(index);
+      // Both bounds are checked here: the comparison against the row count is
+      // signed, so a negative index would otherwise pass it and be handed to
+      // Tensor::Slice below, which is not expected to tolerate an
+      // out-of-bounds range.
+      OP_REQUIRES(c, i >= 0 && i < num_rows,
+                  errors::InvalidArgument("Trying to scatter index ", i,
+                                          " from tensor with ", num_rows,
+                                          " rows."));
+      // Take row i as a [1, ...] slice, then drop the leading dimension so the
+      // element has the per-row shape.
+      Tensor tmp = t.Slice(i, i + 1);
+      TensorShape tmp_shape = tmp.shape();
+      tmp_shape.RemoveDim(0);
+      OP_REQUIRES(c, tmp.CopyFrom(tmp, tmp_shape),
+                  errors::Unknown("Unexpected shape error."));
+      // TODO(apassos) maybe not always align; but weird compiler bugs seem to
+      // prevent this.
+      Tensor aligned;
+      OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
+      // TODO(apassos) do all slices in a single kernel invocation instead of
+      // many small ones.
+      aligned.flat<T>().device(c->eigen_device<Device>()) =
+          tmp.unaligned_flat<T>();
+      output_list.tensors.push_back(std::move(aligned));
+    }
+    output_tensor->scalar<Variant>()() = std::move(output_list);
+  }
+};
+
template <typename Device>
Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a,
const TensorList& b, TensorList* out) {