diff options
author | 2018-08-30 10:14:22 -0700 | |
---|---|---|
committer | 2018-08-30 10:23:30 -0700 | |
commit | ee89fccfd1db25563dadd0e3b4336612d7c52e0a (patch) | |
tree | 83e0034a75e66a60a97d675238e93ce2278d3c7a /tensorflow/core/kernels/list_kernels.h | |
parent | ed248787e7045cc484fd7cff3d2447c5c776aa84 (diff) |
TensorListScatter and TensorListGather
This closes one API hole between TensorList and TensorArray
PiperOrigin-RevId: 210932049
Diffstat (limited to 'tensorflow/core/kernels/list_kernels.h')
-rw-r--r-- | tensorflow/core/kernels/list_kernels.h | 121 |
1 files changed, 121 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h index b3f74c060b..066a1d603b 100644 --- a/tensorflow/core/kernels/list_kernels.h +++ b/tensorflow/core/kernels/list_kernels.h @@ -134,6 +134,74 @@ class TensorListStack : public OpKernel { }; template <typename Device, typename T> +class TensorListGather : public OpKernel { + public: + typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>> + ConstMatrixVector; + explicit TensorListGather(OpKernelConstruction* c) : OpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_)); + } + + void Compute(OpKernelContext* c) override { + const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>(); + OP_REQUIRES(c, l != nullptr, + errors::InvalidArgument( + "Input handle is not a list. Saw: '", + c->input(0).scalar<Variant>()().DebugString(), "'")); + OP_REQUIRES(c, element_dtype_ == l->element_dtype, + errors::InvalidArgument("Invalid data types; op elements ", + DataTypeString(element_dtype_), + " but list elements ", + DataTypeString(l->element_dtype))); + OP_REQUIRES(c, l->element_shape.IsFullyDefined(), + errors::InvalidArgument("Tried to stack elements from a list " + "with non-fully-defined shape: ", + l->element_shape.DebugString())); + Tensor indices = c->input(1); + TensorShape resulting_shape; + resulting_shape.AddDim(indices.NumElements()); + for (TensorShapeDim s : l->element_shape) { + resulting_shape.AddDim(s.size); + } + Tensor* output; + OP_REQUIRES_OK(c, c->allocate_output(0, resulting_shape, &output)); + if (output->NumElements() == 0) { + return; + } + + ConstMatrixVector inputs_flat; + inputs_flat.reserve(l->tensors.size()); + for (int index = 0; index < indices.NumElements(); ++index) { + const int i = indices.flat<int32>()(index); + OP_REQUIRES( + c, i < l->tensors.size(), + errors::InvalidArgument("Index ", i, " out o range; list only has ", + l->tensors.size(), " elements.")); + const Tensor& t = l->tensors[i]; + OP_REQUIRES(c, l->element_shape.IsCompatibleWith(t.shape()), + errors::InvalidArgument( + "Tensor with invalid shape in list. List element shape: ", + l->element_shape.DebugString(), + " and tensor shape: ", t.shape().DebugString())); + inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix( + t.shaped<T, 2>({1, t.NumElements()}))); + } + auto output_flat = output->shaped<T, 2>({1, output->NumElements()}); + +#if GOOGLE_CUDA + if (std::is_same<Device, Eigen::GpuDevice>::value) { + ConcatGPU<T>(c, inputs_flat, output, &output_flat); + return; + } +#endif // GOOGLE_CUDA + ConcatCPU<T>(c->device(), inputs_flat, &output_flat); + } + + private: + DataType element_dtype_; +}; + +template <typename Device, typename T> class TensorListFromTensor : public OpKernel { public: TensorListFromTensor(OpKernelConstruction* c) : OpKernel(c) {} @@ -178,6 +246,59 @@ class TensorListFromTensor : public OpKernel { } }; +template <typename Device, typename T> +class TensorListScatter : public OpKernel { + public: + TensorListScatter(OpKernelConstruction* c) : OpKernel(c) {} + + void Compute(OpKernelContext* c) override { + Tensor* output_tensor; + AllocatorAttributes attr; + attr.set_on_host(true); + OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr)); + Tensor indices = c->input(1); + PartialTensorShape element_shape; + OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(2), &element_shape)); + TensorList output_list; + const Tensor& t = c->input(0); + output_list.element_dtype = t.dtype(); + OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(t.shape()), + errors::InvalidArgument( + "Tensor must be at least a vector, but saw shape: ", + t.shape().DebugString())); + TensorShape output_shape(t.shape()); + output_shape.RemoveDim(0); + OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape), + errors::InvalidArgument( + "Specified a list with shape ", element_shape.DebugString(), + " from a tensor with shape ", output_shape.DebugString())); + output_list.element_shape = element_shape; + output_list.tensors.reserve(indices.NumElements()); + for (int index = 0; index < indices.NumElements(); ++index) { + const int i = indices.flat<int32>()(index); + OP_REQUIRES(c, i < t.shape().dim_size(0), + errors::InvalidArgument("Trying to scatter index ", i, + " from tensor with ", + t.shape().dim_size(0), " rows.")); + Tensor tmp = t.Slice(i, i + 1); + TensorShape tmp_shape = tmp.shape(); + tmp_shape.RemoveDim(0); + OP_REQUIRES(c, tmp.CopyFrom(tmp, tmp_shape), + errors::Unknown("Unexpected shape error.")); + // TODO(apassos) maybe not always align; but weird compiler bugs seem to + // prevent this. + Tensor aligned; + OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned)); + // TODO(apassos) do all slices in a single kernel invocation instead of + // many small ondes. + aligned.flat<T>().device(c->eigen_device<Device>()) = + tmp.unaligned_flat<T>(); + output_list.tensors.push_back(aligned); + } + output_tensor->scalar<Variant>()() = std::move(output_list); + } +}; + template <typename Device> Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a, const TensorList& b, TensorList* out) { |