diff options
Diffstat (limited to 'tensorflow/core')
-rw-r--r-- | tensorflow/core/api_def/base_api/api_def_TensorListGather.pbtxt | 12 | ||||
-rw-r--r-- | tensorflow/core/api_def/base_api/api_def_TensorListScatter.pbtxt | 14 | ||||
-rw-r--r-- | tensorflow/core/kernels/list_kernels.cc | 12 | ||||
-rw-r--r-- | tensorflow/core/kernels/list_kernels.cu.cc | 15 | ||||
-rw-r--r-- | tensorflow/core/kernels/list_kernels.h | 121 | ||||
-rw-r--r-- | tensorflow/core/ops/list_ops.cc | 51 |
6 files changed, 219 insertions, 6 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_TensorListGather.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorListGather.pbtxt new file mode 100644 index 0000000000..3022fccb1e --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_TensorListGather.pbtxt @@ -0,0 +1,12 @@ +op { + graph_op_name: "TensorListGather" + summary: "Creates a Tensor by indexing into the TensorList." + description: <<END +Each row in the produced Tensor corresponds to the element in the TensorList +specified by the given index (see `tf.gather`). + +input_handle: The input tensor list. +indices: The indices used to index into the list. +values: The tensor. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_TensorListScatter.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorListScatter.pbtxt new file mode 100644 index 0000000000..35194b353e --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_TensorListScatter.pbtxt @@ -0,0 +1,14 @@ +op { + graph_op_name: "TensorListScatter" + summary: "Creates a TensorList by indexing into a Tensor." + description: <<END +Each member of the TensorList corresponds to one row of the input tensor, +specified by the given index (see `tf.gather`). + +tensor: The input tensor. +indices: The indices used to index into the tensor. +element_shape: The shape of the elements in the list (can be less specified than + the shape of the tensor). +output_handle: The TensorList. 
+END +} diff --git a/tensorflow/core/kernels/list_kernels.cc b/tensorflow/core/kernels/list_kernels.cc index 84fa63fc00..bca1cff41c 100644 --- a/tensorflow/core/kernels/list_kernels.cc +++ b/tensorflow/core/kernels/list_kernels.cc @@ -588,7 +588,11 @@ REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU(bfloat16); REGISTER_KERNEL_BUILDER(Name("TensorListStack") \ .TypeConstraint<T>("element_dtype") \ .Device(DEVICE_CPU), \ - TensorListStack<CPUDevice, T>) + TensorListStack<CPUDevice, T>) \ + REGISTER_KERNEL_BUILDER(Name("TensorListGather") \ + .TypeConstraint<T>("element_dtype") \ + .Device(DEVICE_CPU), \ + TensorListGather<CPUDevice, T>) TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_STACK_CPU); REGISTER_TENSOR_LIST_STACK_CPU(quint8); @@ -604,7 +608,11 @@ REGISTER_TENSOR_LIST_STACK_CPU(bfloat16); REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor") \ .TypeConstraint<T>("element_dtype") \ .Device(DEVICE_CPU), \ - TensorListFromTensor<CPUDevice, T>) + TensorListFromTensor<CPUDevice, T>) \ + REGISTER_KERNEL_BUILDER(Name("TensorListScatter") \ + .TypeConstraint<T>("element_dtype") \ + .Device(DEVICE_CPU), \ + TensorListScatter<CPUDevice, T>) TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_FROM_TENSOR_CPU); REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(quint8); diff --git a/tensorflow/core/kernels/list_kernels.cu.cc b/tensorflow/core/kernels/list_kernels.cu.cc index 0ea9362cbe..c591226b76 100644 --- a/tensorflow/core/kernels/list_kernels.cu.cc +++ b/tensorflow/core/kernels/list_kernels.cu.cc @@ -40,7 +40,12 @@ typedef Eigen::GpuDevice GPUDevice; REGISTER_KERNEL_BUILDER(Name("TensorListStack") \ .TypeConstraint<T>("element_dtype") \ .Device(DEVICE_GPU), \ - TensorListStack<GPUDevice, T>) + TensorListStack<GPUDevice, T>) \ + REGISTER_KERNEL_BUILDER(Name("TensorListGather") \ + .TypeConstraint<T>("element_dtype") \ + .Device(DEVICE_GPU) \ + .HostMemory("indices"), \ + TensorListGather<GPUDevice, T>) TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_STACK_GPU); 
REGISTER_TENSOR_LIST_STACK_GPU(bfloat16); @@ -71,7 +76,13 @@ REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_GPU(bool); .TypeConstraint<T>("element_dtype") \ .Device(DEVICE_GPU) \ .HostMemory("element_shape"), \ - TensorListFromTensor<GPUDevice, T>) + TensorListFromTensor<GPUDevice, T>) \ + REGISTER_KERNEL_BUILDER(Name("TensorListScatter") \ + .TypeConstraint<T>("element_dtype") \ + .Device(DEVICE_GPU) \ + .HostMemory("element_shape") \ + .HostMemory("indices"), \ + TensorListScatter<GPUDevice, T>) TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_FROM_TENSOR_GPU); REGISTER_TENSOR_LIST_FROM_TENSOR_GPU(bfloat16); diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h index b3f74c060b..066a1d603b 100644 --- a/tensorflow/core/kernels/list_kernels.h +++ b/tensorflow/core/kernels/list_kernels.h @@ -134,6 +134,74 @@ class TensorListStack : public OpKernel { }; template <typename Device, typename T> +class TensorListGather : public OpKernel { + public: + typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>> + ConstMatrixVector; + explicit TensorListGather(OpKernelConstruction* c) : OpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_)); + } + + void Compute(OpKernelContext* c) override { + const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>(); + OP_REQUIRES(c, l != nullptr, + errors::InvalidArgument( + "Input handle is not a list. 
Saw: '", + c->input(0).scalar<Variant>()().DebugString(), "'")); + OP_REQUIRES(c, element_dtype_ == l->element_dtype, + errors::InvalidArgument("Invalid data types; op elements ", + DataTypeString(element_dtype_), + " but list elements ", + DataTypeString(l->element_dtype))); + OP_REQUIRES(c, l->element_shape.IsFullyDefined(), + errors::InvalidArgument("Tried to stack elements from a list " + "with non-fully-defined shape: ", + l->element_shape.DebugString())); + Tensor indices = c->input(1); + TensorShape resulting_shape; + resulting_shape.AddDim(indices.NumElements()); + for (TensorShapeDim s : l->element_shape) { + resulting_shape.AddDim(s.size); + } + Tensor* output; + OP_REQUIRES_OK(c, c->allocate_output(0, resulting_shape, &output)); + if (output->NumElements() == 0) { + return; + } + + ConstMatrixVector inputs_flat; + inputs_flat.reserve(l->tensors.size()); + for (int index = 0; index < indices.NumElements(); ++index) { + const int i = indices.flat<int32>()(index); + OP_REQUIRES( + c, i < l->tensors.size(), + errors::InvalidArgument("Index ", i, " out o range; list only has ", + l->tensors.size(), " elements.")); + const Tensor& t = l->tensors[i]; + OP_REQUIRES(c, l->element_shape.IsCompatibleWith(t.shape()), + errors::InvalidArgument( + "Tensor with invalid shape in list. 
List element shape: ", + l->element_shape.DebugString(), + " and tensor shape: ", t.shape().DebugString())); + inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix( + t.shaped<T, 2>({1, t.NumElements()}))); + } + auto output_flat = output->shaped<T, 2>({1, output->NumElements()}); + +#if GOOGLE_CUDA + if (std::is_same<Device, Eigen::GpuDevice>::value) { + ConcatGPU<T>(c, inputs_flat, output, &output_flat); + return; + } +#endif // GOOGLE_CUDA + ConcatCPU<T>(c->device(), inputs_flat, &output_flat); + } + + private: + DataType element_dtype_; +}; + +template <typename Device, typename T> class TensorListFromTensor : public OpKernel { public: TensorListFromTensor(OpKernelConstruction* c) : OpKernel(c) {} @@ -178,6 +246,59 @@ class TensorListFromTensor : public OpKernel { } }; +template <typename Device, typename T> +class TensorListScatter : public OpKernel { + public: + TensorListScatter(OpKernelConstruction* c) : OpKernel(c) {} + + void Compute(OpKernelContext* c) override { + Tensor* output_tensor; + AllocatorAttributes attr; + attr.set_on_host(true); + OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr)); + Tensor indices = c->input(1); + PartialTensorShape element_shape; + OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(2), &element_shape)); + TensorList output_list; + const Tensor& t = c->input(0); + output_list.element_dtype = t.dtype(); + OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(t.shape()), + errors::InvalidArgument( + "Tensor must be at least a vector, but saw shape: ", + t.shape().DebugString())); + TensorShape output_shape(t.shape()); + output_shape.RemoveDim(0); + OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape), + errors::InvalidArgument( + "Specified a list with shape ", element_shape.DebugString(), + " from a tensor with shape ", output_shape.DebugString())); + output_list.element_shape = element_shape; + output_list.tensors.reserve(indices.NumElements()); + for (int index = 0; index < indices.NumElements(); 
++index) { + const int i = indices.flat<int32>()(index); + OP_REQUIRES(c, i < t.shape().dim_size(0), + errors::InvalidArgument("Trying to scatter index ", i, + " from tensor with ", + t.shape().dim_size(0), " rows.")); + Tensor tmp = t.Slice(i, i + 1); + TensorShape tmp_shape = tmp.shape(); + tmp_shape.RemoveDim(0); + OP_REQUIRES(c, tmp.CopyFrom(tmp, tmp_shape), + errors::Unknown("Unexpected shape error.")); + // TODO(apassos) maybe not always align; but weird compiler bugs seem to + // prevent this. + Tensor aligned; + OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned)); + // TODO(apassos) do all slices in a single kernel invocation instead of + // many small ones. + aligned.flat<T>().device(c->eigen_device<Device>()) = + tmp.unaligned_flat<T>(); + output_list.tensors.push_back(aligned); + } + output_tensor->scalar<Variant>()() = std::move(output_list); + } +}; + template <typename Device> Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a, const TensorList& b, TensorList* out) { diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc index b9f94ba1c5..7d79df9c1c 100644 --- a/tensorflow/core/ops/list_ops.cc +++ b/tensorflow/core/ops/list_ops.cc @@ -210,7 +210,8 @@ REGISTER_OP("TensorListFromTensor") shape_inference::ShapeHandle o; TF_RETURN_IF_ERROR(c->Subshape(s, 1, &o)); shape_inference::ShapeHandle element_shape; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &element_shape)); + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape( + 1, &element_shape)); TF_RETURN_IF_ERROR(c->Merge(o, element_shape, &o)); c->set_output_handle_shapes_and_types( 0, std::vector<shape_inference::ShapeAndType>{{element_shape, t}}); @@ -240,7 +241,8 @@ REGISTER_OP("TensorListReserve") .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->Scalar()); shape_inference::ShapeHandle s; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); + TF_RETURN_IF_ERROR( + 
c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(0, &s)); DataType t; TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t)); c->set_output_handle_shapes_and_types( @@ -295,6 +297,51 @@ REGISTER_OP("TensorListSetItem") return Status::OK(); }); +REGISTER_OP("TensorListGather") + .Input("input_handle: variant") + .Input("indices: int32") + .Output("values: element_dtype") + .Attr("element_dtype: type") + .SetShapeFn([](shape_inference::InferenceContext* c) { + DataType t; + TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t)); + auto* handle_data = c->input_handle_shapes_and_types(0); + shape_inference::ShapeHandle element_shape = c->UnknownShape(); + if (handle_data != nullptr) { + const shape_inference::ShapeAndType& list_shape_type = + (*handle_data)[0]; + element_shape = list_shape_type.shape; + if (list_shape_type.dtype != t) { + return errors::InvalidArgument("Expected list with element dtype ", + DataTypeString(t), + " but got list with element dtype ", + DataTypeString(list_shape_type.dtype)); + } + } + shape_inference::ShapeHandle out; + TF_RETURN_IF_ERROR(c->Concatenate(c->input(1), element_shape, &out)); + c->set_output(0, out); + return Status::OK(); + }); + +REGISTER_OP("TensorListScatter") + .Input("tensor: element_dtype") + .Input("indices: int32") + .Input("element_shape: shape_type") + .Output("output_handle: variant") + .Attr("element_dtype: type") + .Attr("shape_type: {int32, int64}") + .SetShapeFn([](shape_inference::InferenceContext* c) { + DataType t; + TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t)); + shape_inference::ShapeHandle s; + TF_RETURN_IF_ERROR( + c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(2, &s)); + c->set_output_handle_shapes_and_types(0, {{s, t}}); + c->set_output(0, c->Scalar()); + return Status::OK(); + }); + REGISTER_OP("TensorListConcatLists") .Input("input_a: variant") .Input("input_b: variant") |