author    A. Unique TensorFlower <gardener@tensorflow.org> 2018-10-09 12:24:05 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-10-09 12:27:54 -0700
commit    0c6baae5af46bb22ea52db724e2194845d3bbf8c (patch)
tree      317205fa8cdc75b1d9382df73aae65817528ae7f
parent    8c2a52b26f21167ed0fcec7859850e38d0c216f9 (diff)
Add RaggedTensors to tf.core. Moving the RaggedGather op kernel.
PiperOrigin-RevId: 216400726
-rw-r--r--  tensorflow/core/BUILD                                          15
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_RaggedGather.pbtxt    81
-rw-r--r--  tensorflow/core/kernels/BUILD                                  31
-rw-r--r--  tensorflow/core/kernels/ragged_gather_op.cc                   292
-rw-r--r--  tensorflow/core/kernels/ragged_gather_op_test.cc              281
-rw-r--r--  tensorflow/core/ops/ragged_array_ops.cc                        85
6 files changed, 785 insertions(+), 0 deletions(-)
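
[Editor's note: for orientation, a minimal Python sketch of the row-splits encoding that the new RaggedGather kernel consumes. The helper `to_row_splits` is hypothetical and not part of this change; it only illustrates how a ragged dimension is stored as a flat values tensor plus a `row_splits` vector.]

```python
# Hypothetical helper: flatten one ragged nesting level into a row_splits
# vector plus the concatenated row values (the encoding RaggedGather consumes).
def to_row_splits(ragged_rows):
    splits, values = [0], []
    for row in ragged_rows:
        values.extend(row)
        splits.append(len(values))
    return splits, values

# [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]]  ->  splits + flat values
splits, values = to_row_splits([[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]])
# splits == [0, 3, 3, 7, 9]
# values == [.1, .2, .3, .4, .5, .6, .7, .8, .9]
```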
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index acea8e2217..9e7806342a 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1155,6 +1155,19 @@ tf_gen_op_libs(
)
cc_library(
+ name = "ragged_ops",
+ deps = [
+ ":ragged_array_ops_op_lib",
+ ],
+)
+
+tf_gen_op_libs(
+ op_lib_names = [
+ "ragged_array_ops",
+ ],
+)
+
+cc_library(
name = "ops",
visibility = ["//visibility:public"],
deps = [
@@ -1187,6 +1200,7 @@ cc_library(
":nn_ops_op_lib",
":no_op_op_lib",
":parsing_ops_op_lib",
+ ":ragged_ops",
":random_ops_op_lib",
":remote_fused_graph_ops_op_lib",
":resource_variable_ops_op_lib",
@@ -1340,6 +1354,7 @@ cc_library(
"//tensorflow/core/kernels:parameterized_truncated_normal_op",
"//tensorflow/core/kernels:parsing",
"//tensorflow/core/kernels:partitioned_function_ops",
+ "//tensorflow/core/kernels:ragged_ops",
"//tensorflow/core/kernels:random_ops",
"//tensorflow/core/kernels:random_poisson_op",
"//tensorflow/core/kernels:remote_fused_graph_ops",
diff --git a/tensorflow/core/api_def/base_api/api_def_RaggedGather.pbtxt b/tensorflow/core/api_def/base_api/api_def_RaggedGather.pbtxt
new file mode 100644
index 0000000000..240c987dda
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RaggedGather.pbtxt
@@ -0,0 +1,81 @@
+op {
+ graph_op_name: "RaggedGather"
+ visibility: HIDDEN
+ in_arg {
+ name: "params_nested_splits"
+ description: <<END
+The `nested_row_splits` tensors that define the row-partitioning for the
+`params` RaggedTensor input.
+END
+ }
+ in_arg {
+ name: "params_dense_values"
+ description: <<END
+The `inner_values` for the `params` RaggedTensor. (At the Python level, the
+terminology changed from `dense_values` to `inner_values`, so `dense_values`
+is the deprecated name.)
+END
+ }
+ in_arg {
+ name: "indices"
+ description: <<END
+Indices in the outermost dimension of `params` of the values that should be
+gathered.
+END
+ }
+ out_arg {
+ name: "output_nested_splits"
+ description: <<END
+The `nested_row_splits` tensors that define the row-partitioning for the
+returned RaggedTensor.
+END
+ }
+ out_arg {
+ name: "output_dense_values"
+ description: "The `inner_values` for the returned RaggedTensor."
+ }
+ attr {
+ name: "PARAMS_RAGGED_RANK"
+ description: <<END
+The ragged rank of the `params` RaggedTensor. `params_nested_splits` should
+contain this number of `row_splits` tensors. This value should equal
+`params.ragged_rank`.
+END
+ }
+ attr {
+ name: "OUTPUT_RAGGED_RANK"
+ description: <<END
+The ragged rank of the output RaggedTensor. `output_nested_splits` will contain
+this number of `row_splits` tensors. This value should equal
+`indices.shape.ndims + params.ragged_rank - 1`.
+END
+ }
+ summary: <<END
+Gather ragged slices from `params` axis `0` according to `indices`.
+END
+ description: <<END
+Outputs a `RaggedTensor` output composed from `output_dense_values` and
+`output_nested_splits`, such that:
+
+```python
+output.shape = indices.shape + params.shape[1:]
+output.ragged_rank = indices.shape.ndims + params.ragged_rank - 1
+output[i...j, d0...dn] = params[indices[i...j], d0...dn]
+```
+
+where
+
+* `params =
+ ragged.from_nested_row_splits(params_dense_values, params_nested_splits)`
+ provides the values that should be gathered.
+* `indices` is a dense tensor with dtype `int32` or `int64`, indicating which
+ values should be gathered.
+* `output =
+ ragged.from_nested_row_splits(output_dense_values, output_nested_splits)`
+ is the output tensor.
+
+(Note: This C++ op is used to implement the higher-level Python
+`tf.ragged.gather` op, which also supports ragged indices.)
+
+END
+}
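
[Editor's note: a minimal Python sketch of the gather semantics the description above specifies, restricted to the `PARAMS_RAGGED_RANK = 1`, 1-D `indices` case. `ragged_gather_rank1` is a hypothetical illustration, not the kernel's code path.]

```python
# Gather rows of a ragged-rank-1 params tensor, rebuilding the output
# row_splits and flat values as the description above specifies (illustrative).
def ragged_gather_rank1(params_splits, params_values, indices):
    out_splits, out_values = [0], []
    for i in indices:
        start, limit = params_splits[i], params_splits[i + 1]
        out_values.extend(params_values[start:limit])
        out_splits.append(out_splits[-1] + (limit - start))
    return out_splits, out_values

# params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]], indices = [2, 1, 0, 3]
splits, values = ragged_gather_rank1(
    [0, 3, 3, 7, 9], [.1, .2, .3, .4, .5, .6, .7, .8, .9], [2, 1, 0, 3])
# splits == [0, 4, 4, 7, 9]
# values == [.4, .5, .6, .7, .1, .2, .3, .8, .9]
```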
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 3a920f26f3..1ca9c7b7f5 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -958,6 +958,37 @@ tf_kernel_library(
]) + ARRAY_DEPS,
)
+cc_library(
+ name = "ragged_ops",
+ deps = [
+ ":ragged_gather_op",
+ ],
+)
+
+tf_kernel_library(
+ name = "ragged_gather_op",
+ srcs = ["ragged_gather_op.cc"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:ragged_array_ops_op_lib",
+ ],
+)
+
+tf_cc_test(
+ name = "ragged_gather_op_test",
+ size = "small",
+ srcs = ["ragged_gather_op_test.cc"],
+ deps = [
+ ":ragged_gather_op",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:ragged_array_ops_op_lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ ],
+)
+
tf_kernel_library(
name = "cudnn_rnn_kernels",
srcs = ["cudnn_rnn_ops.cc"],
diff --git a/tensorflow/core/kernels/ragged_gather_op.cc b/tensorflow/core/kernels/ragged_gather_op.cc
new file mode 100644
index 0000000000..b2a342f637
--- /dev/null
+++ b/tensorflow/core/kernels/ragged_gather_op.cc
@@ -0,0 +1,292 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/util/util.h"
+
+namespace tensorflow {
+
+namespace {
+
+// For each slice `(start, limit)` in `value_slices`, append
+// `params_dense_values_in[start:limit]` to `values_out`. `value_size` is the
+// number of scalars contained in each value `params_dense_values_in[i]`.
+template <typename VALUE_TYPE>
+void WriteValueSlices(const Tensor& params_dense_values_in,
+ const std::vector<std::pair<int64, int64>>& value_slices,
+ int64 value_size, Tensor* values_out) {
+ const auto& params_dense_values =
+ params_dense_values_in.flat_outer_dims<VALUE_TYPE, 2>();
+ auto values = values_out->flat_outer_dims<VALUE_TYPE, 2>();
+ int out_pos = 0;
+ for (const auto& slice : value_slices) {
+ for (int i = slice.first; i < slice.second; ++i) {
+ for (int j = 0; j < value_size; ++j) {
+ values(out_pos, j) = params_dense_values(i, j);
+ }
+ ++out_pos;
+ }
+ }
+}
+
+} // namespace
+
+template <typename INDEX_TYPE>
+class RaggedGatherOpBase : public OpKernel {
+ public:
+ using OpKernel::OpKernel;
+
+ void Compute(OpKernelContext* context) override {
+ // Get the input Tensors.
+ OpInputList params_nested_splits_in;
+ OP_REQUIRES_OK(context, context->input_list("params_nested_splits",
+ &params_nested_splits_in));
+ const Tensor& params_dense_values_in =
+ context->input(params_nested_splits_in.size());
+ const Tensor& indices_in =
+ context->input(params_nested_splits_in.size() + 1);
+
+ DCHECK_GT(params_nested_splits_in.size(), 0); // Enforced by REGISTER_OP.
+ int64 num_params = params_nested_splits_in[0].dim_size(0) - 1;
+ OP_REQUIRES_OK(context, ValidateIndices(indices_in, num_params));
+
+ OP_REQUIRES(context, params_dense_values_in.dims() > 0,
+ errors::InvalidArgument("params.rank must be nonzero"));
+ int64 num_params_dense_values = params_dense_values_in.dim_size(0);
+
+ // Calculate the `splits`, and store the value slices that we need to
+ // copy in `value_slices`.
+ std::vector<std::pair<int64, int64>> value_slices;
+ int64 num_values = 0;
+ std::vector<std::vector<int64>> out_splits;
+ OP_REQUIRES_OK(context, MakeSplits(indices_in, params_nested_splits_in,
+ num_params_dense_values, &out_splits,
+ &value_slices, &num_values));
+
+ // Write the output tensors.
+ OP_REQUIRES_OK(context, WriteSplits(out_splits, context));
+ OP_REQUIRES_OK(context,
+ WriteValues(params_dense_values_in, value_slices,
+ out_splits.size(), num_values, context));
+ }
+
+ private:
+ // Check if any indices are out-of-bounds.
+ ::tensorflow::Status ValidateIndices(const Tensor& indices_in,
+ int64 num_params) {
+ const auto& indices = indices_in.flat<INDEX_TYPE>();
+ for (int64 i = 0; i < indices.size(); ++i) {
+ int64 index = indices(i);
+ if (index < 0 || index >= num_params) {
+ return errors::InvalidArgument(
+ "indices", SliceDebugString(indices_in.shape(), i), " = ", index,
+ " is not in [0, ", num_params, ")");
+ }
+ }
+ return ::tensorflow::Status::OK();
+ }
+
+ // Construct the `splits` output tensors, encoded using a nested vector.
+ // Also find the slices of values that need to be copied, and store them
+ // in `value_slices`. The total number of values that will be copied (which
+ // we need for allocating the output values tensor) is stored in `num_values`.
+ ::tensorflow::Status MakeSplits(
+ const Tensor& indices_in, const OpInputList& params_nested_splits_in,
+ int64 num_params_dense_values,
+ std::vector<std::vector<int64>>* out_splits,
+ std::vector<std::pair<int64, int64>>* value_slices, int64* num_values) {
+ *num_values = 0;
+ value_slices->clear();
+
+ int num_splits = indices_in.dims() - 1 + params_nested_splits_in.size();
+ out_splits->assign(num_splits, {0});
+
+ // Get Eigen tensors.
+ const auto& indices = indices_in.flat<INDEX_TYPE>();
+ std::vector<TTypes<int64>::ConstFlat> params_nested_splits;
+ params_nested_splits.reserve(params_nested_splits_in.size());
+ for (const auto& splits_in : params_nested_splits_in) {
+ params_nested_splits.push_back(splits_in.flat<int64>());
+ }
+
+ TF_RETURN_IF_ERROR(
+ ValidateSplits(params_nested_splits, num_params_dense_values));
+
+ // Add `splits` that come from all but the last dimension of the dense
+ // Tensor `indices`. In particular, for each dimension D, we add a
+ // splits tensor whose values are:
+ //   range(0, indices.shape[D]*indices.shape[D+1] + 1, step=indices.shape[D+1])
+ // E.g., if indices.shape=[5, 3] then we will add a splits tensor
+ // [0, 3, 6, 9, 12, 15], since the outermost dimension has 5 elements,
+ // each of which contains 3 values.
+ for (int dim = 0; dim < indices_in.dims() - 1; ++dim) {
+ int stride = indices_in.dim_size(dim + 1);
+ int index = stride;
+ for (int i = 0; i < indices_in.dim_size(dim); ++i) {
+ out_splits->at(dim).push_back(index);
+ index += stride;
+ }
+ }
+
+ // Add `splits` that come from `params_nested_splits`. Starting with the
+ // outermost ragged dimension (i.e., the first `splits` tensor), we work
+ // our way in, finding the range of values that should be copied. As we
+ // go, we update the output `splits` for each dimension with the appropriate
+ // values. In particular, the *lengths* of the slices from `param_splits`
+ // should be copied to generate corresponding slice lengths in the output
+ // splits. E.g., if we are copying a ragged row with length 4, then we
+ // should add a new split point to out_splits that is 4 greater than the
+ // previous split point in out_splits.
+ for (int i = 0; i < indices.size(); ++i) {
+ int start = indices(i);
+ int limit = indices(i) + 1;
+
+ // Copy splits.
+ for (int dim = 0; dim < params_nested_splits.size(); ++dim) {
+ const auto& splits = params_nested_splits[dim];
+ int out_dim = dim + indices_in.dims() - 1;
+ if (out_dim >= 0) {
+ int64 delta = out_splits->at(out_dim).back() - splits(start);
+ for (int j = start; j < limit; ++j) {
+ out_splits->at(out_dim).push_back(splits(j + 1) + delta);
+ }
+ }
+ start = splits(start);
+ limit = splits(limit);
+ }
+ if (limit != start) {
+ value_slices->emplace_back(start, limit);
+ *num_values += limit - start;
+ }
+ }
+ return ::tensorflow::Status::OK();
+ }
+
+ ::tensorflow::Status ValidateSplits(
+ const std::vector<TTypes<int64>::ConstFlat>& params_nested_splits,
+ int64 num_params_dense_values) {
+ // Validate each row_splits tensor: non-empty, non-negative, sorted, in range.
+ for (int dim = 0; dim < params_nested_splits.size(); ++dim) {
+ const auto& splits = params_nested_splits[dim];
+ int64 last_split = (dim == params_nested_splits.size() - 1)
+ ? num_params_dense_values
+ : params_nested_splits[dim + 1].size();
+ if (splits.size() == 0) {
+ return errors::InvalidArgument("Ragged splits may not be empty");
+ }
+ if (splits(0) < 0) {
+ return errors::InvalidArgument("Ragged splits must be non-negative");
+ }
+ if (splits(splits.size() - 1) > last_split) {
+ return errors::InvalidArgument(
+ "Ragged splits must not point past values");
+ }
+ for (int i = 1; i < splits.size(); ++i) {
+ if (splits(i - 1) > splits(i)) {
+ return errors::InvalidArgument("Ragged splits must be sorted");
+ }
+ }
+ }
+ return ::tensorflow::Status::OK();
+ }
+
+ ::tensorflow::Status WriteSplits(
+ const std::vector<std::vector<int64>>& out_splits,
+ OpKernelContext* context) {
+ OpOutputList splits_out;
+ TF_RETURN_IF_ERROR(
+ context->output_list("output_nested_splits", &splits_out));
+ for (int i = 0; i < out_splits.size(); ++i) {
+ Tensor* splits;
+ int64 num_splits = out_splits[i].size();
+ TF_RETURN_IF_ERROR(
+ splits_out.allocate(i, TensorShape({num_splits}), &splits));
+ auto splits_flat = splits->flat<int64>();
+ std::copy_n(out_splits[i].data(), out_splits[i].size(),
+ splits_flat.data());
+ }
+ return ::tensorflow::Status::OK();
+ }
+
+ ::tensorflow::Status WriteValues(
+ const Tensor& params_dense_values_in,
+ const std::vector<std::pair<int64, int64>>& value_slices,
+ int values_index, int64 num_values, OpKernelContext* context) const {
+ Tensor* values_out = nullptr;
+ TensorShape values_shape = params_dense_values_in.shape();
+ values_shape.set_dim(0, num_values);
+ TF_RETURN_IF_ERROR(
+ context->allocate_output(values_index, values_shape, &values_out));
+ int64 value_size = params_dense_values_in.NumElements() /
+ params_dense_values_in.dim_size(0);
+ CallWriteValueSlices(params_dense_values_in, value_slices, value_size,
+ values_out);
+ return ::tensorflow::Status::OK();
+ }
+
+ protected:
+ // Call WriteValueSlices() using the appropriate VALUE_TYPE template
+ // parameter. This pattern is used to reduce binary size. In particular,
+ // this allows us to have two instantiations of this class (one for each
+ // index type), rather than 14 (one for each index type and value type),
+ // which cuts the binary size of this op from ~300k to <90k.
+ virtual void CallWriteValueSlices(
+ const Tensor& params_dense_values_in,
+ const std::vector<std::pair<int64, int64>>& value_slices,
+ int64 value_size, Tensor* values_out) const = 0;
+};
+
+template <typename INDEX_TYPE, typename VALUE_TYPE>
+class RaggedGatherOp : public RaggedGatherOpBase<INDEX_TYPE> {
+ public:
+ using RaggedGatherOpBase<INDEX_TYPE>::RaggedGatherOpBase;
+
+ private:
+ void CallWriteValueSlices(
+ const Tensor& params_dense_values_in,
+ const std::vector<std::pair<int64, int64>>& value_slices,
+ int64 value_size, Tensor* values_out) const override {
+ WriteValueSlices<VALUE_TYPE>(params_dense_values_in, value_slices,
+ value_size, values_out);
+ }
+};
+
+#define REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(index_type, value_type) \
+ REGISTER_KERNEL_BUILDER(Name("RaggedGather") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<index_type>("Tindices") \
+ .TypeConstraint<value_type>("Tvalues"), \
+ RaggedGatherOp<index_type, value_type>);
+#define REGISTER_CPU_KERNEL(value_type) \
+ REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int32, value_type) \
+ REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int64, value_type)
+TF_CALL_POD_TYPES(REGISTER_CPU_KERNEL);
+TF_CALL_string(REGISTER_CPU_KERNEL);
+TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL);
+TF_CALL_quint16(REGISTER_CPU_KERNEL);
+TF_CALL_qint16(REGISTER_CPU_KERNEL);
+TF_CALL_uint32(REGISTER_CPU_KERNEL);
+TF_CALL_uint64(REGISTER_CPU_KERNEL);
+#undef REGISTER_CPU_KERNEL
+#undef REGISTER_CPU_KERNEL_WITH_INDEX_TYPE
+
+} // namespace tensorflow
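
[Editor's note: a rough Python sketch of the splits propagation that `MakeSplits()` above describes, limited to 1-D `indices` and omitting validation. Names mirror the kernel, but this is an approximation of the algorithm, not the C++ code.]

```python
# Walk each nested row_splits tensor from outermost to innermost, widening
# (start, limit) to the range of covered values and emitting output splits
# whose row lengths match the gathered rows (approximation of MakeSplits).
def make_splits_and_slices(indices, params_nested_splits):
    out_splits = [[0] for _ in params_nested_splits]
    value_slices, num_values = [], 0
    for index in indices:
        start, limit = index, index + 1
        for dim, splits in enumerate(params_nested_splits):
            delta = out_splits[dim][-1] - splits[start]
            out_splits[dim].extend(splits[j + 1] + delta
                                   for j in range(start, limit))
            start, limit = splits[start], splits[limit]
        if limit != start:
            value_slices.append((start, limit))
            num_values += limit - start
    return out_splits, value_slices, num_values

# indices=[2, 1, 0, 3] over splits [0, 3, 3, 7, 9] yields
# out_splits == [[0, 4, 4, 7, 9]], value_slices == [(3, 7), (0, 3), (7, 9)]
print(make_splits_and_slices([2, 1, 0, 3], [[0, 3, 3, 7, 9]]))
```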
diff --git a/tensorflow/core/kernels/ragged_gather_op_test.cc b/tensorflow/core/kernels/ragged_gather_op_test.cc
new file mode 100644
index 0000000000..47be788151
--- /dev/null
+++ b/tensorflow/core/kernels/ragged_gather_op_test.cc
@@ -0,0 +1,281 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/shape_inference_testutil.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+class RaggedGatherOpTest : public ::tensorflow::OpsTestBase {
+ protected:
+ // Builds the tensorflow test graph for RaggedGather.
+ template <typename VALUE_TYPE, typename INDEX_TYPE>
+ void BuildRaggedGatherGraph(
+ const TensorShape& indices_shape, const std::vector<INDEX_TYPE>& indices,
+ const std::vector<std::vector<int64>>& params_nested_splits,
+ const TensorShape& params_dense_values_shape,
+ const gtl::ArraySlice<VALUE_TYPE> params_dense_values) {
+ const auto& value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
+ const auto& index_dtype = DataTypeToEnum<INDEX_TYPE>::v();
+ int64 PARAMS_RAGGED_RANK = params_nested_splits.size();
+ int64 num_splits = PARAMS_RAGGED_RANK + indices_shape.dims() - 1;
+ TF_ASSERT_OK(
+ NodeDefBuilder("tested_op", "RaggedGather")
+ .Input(FakeInput(PARAMS_RAGGED_RANK)) // params_nested_splits
+ .Input(FakeInput(value_dtype)) // params_dense_values
+ .Input(FakeInput(index_dtype)) // indices
+ .Attr("PARAMS_RAGGED_RANK", PARAMS_RAGGED_RANK)
+ .Attr("OUTPUT_RAGGED_RANK", num_splits)
+ .Attr("Tvalues", value_dtype)
+ .Attr("Tindices", index_dtype)
+ .Finalize(node_def()));
+ TF_ASSERT_OK(InitOp());
+ for (const auto& splits : params_nested_splits) {
+ int64 splits_size = splits.size();
+ AddInputFromArray<int64>(TensorShape({splits_size}), splits);
+ }
+ AddInputFromArray<VALUE_TYPE>(params_dense_values_shape,
+ params_dense_values);
+ AddInputFromArray<INDEX_TYPE>(indices_shape, indices);
+ }
+};
+
+TEST_F(RaggedGatherOpTest, RaggedGather) {
+ // indices = [2, 1, 0, 3]
+ // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]]
+ // params.shape = [4, None]
+ BuildRaggedGatherGraph<float, int32>(
+ TensorShape({4}), // indices.shape
+ {2, 1, 0, 3}, // indices
+ {{0, 3, 3, 7, 9}}, // params_nested_splits
+ TensorShape({9}), // params_dense_values.shape
+ {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values
+ );
+
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Expected: [[.4, .5, .6, .7], [.1, .2, .3], [], [.8, .9]]
+ test::ExpectTensorEqual<int64>(*GetOutput(0),
+ test::AsTensor<int64>({0, 4, 4, 7, 9}));
+ test::ExpectTensorNear<float>(
+ *GetOutput(1),
+ test::AsTensor<float>({.4, .5, .6, .7, .1, .2, .3, .8, .9}), 0.1);
+}
+
+TEST_F(RaggedGatherOpTest, RaggedGather_3DParams) {
+ // indices = [2, 1, 0, 2, 3]
+ // params = [[[]], [[.1, .2], [.3]], [], [[.4, .5], [.6, .7, .8]], [[.9]]]
+ // params.shape = [5, None, None]
+ BuildRaggedGatherGraph<float, int32>(
+ TensorShape({5}), // indices.shape
+ {2, 1, 0, 2, 3}, // indices
+ {{0, 1, 3, 3, 5, 6}, {0, 0, 2, 3, 5, 8, 9}}, // params_nested_splits
+ TensorShape({9}), // params_dense_values.shape
+ {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values
+ );
+
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Expected: [[], [[.1, .2], [.3]], [[]], [], [[.4, .5], [.6, .7, .8]]]
+ test::ExpectTensorEqual<int64>(*GetOutput(0),
+ test::AsTensor<int64>({0, 0, 2, 3, 3, 5}));
+ test::ExpectTensorEqual<int64>(*GetOutput(1),
+ test::AsTensor<int64>({0, 2, 3, 3, 5, 8}));
+ test::ExpectTensorNear<float>(
+ *GetOutput(2), test::AsTensor<float>({.1, .2, .3, .4, .5, .6, .7, .8}),
+ 0.1);
+}
+
+TEST_F(RaggedGatherOpTest, RaggedGather_4DParams) {
+ // indices = [2, 1, 0, 2]
+ // params = [[[]], [[[1, 2], [3, 4], [5, 6]], [[7, 8]]], []]
+ // params.shape = [3, None, None, 2]
+ BuildRaggedGatherGraph<int32, int32>(
+ TensorShape({4}), // indices.shape
+ {2, 1, 0, 2}, // indices
+ {{0, 1, 3, 3}, {0, 0, 3, 4}}, // params_nested_splits
+ TensorShape({4, 2}), // params_dense_values.shape
+ {1, 2, 3, 4, 5, 6, 7, 8} // params_dense_values
+ );
+
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Expected: [[],
+ // [[[1, 2], [3, 4], [5, 6]], [[7, 8]]],
+ // [[]],
+ // []]
+ test::ExpectTensorEqual<int64>(*GetOutput(0),
+ test::AsTensor<int64>({0, 0, 2, 3, 3}));
+ test::ExpectTensorEqual<int64>(*GetOutput(1),
+ test::AsTensor<int64>({0, 3, 4, 4}));
+ test::ExpectTensorEqual<int32>(
+ *GetOutput(2),
+ test::AsTensor<int32>({1, 2, 3, 4, 5, 6, 7, 8}, TensorShape({4, 2})));
+}
+
+TEST_F(RaggedGatherOpTest, RaggedGather_2DIndices) {
+ // indices = [[2, 1], [0, 3]]
+ // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]]
+ BuildRaggedGatherGraph<float, int32>(
+ TensorShape({2, 2}), // indices.shape
+ {2, 1, 0, 3}, // indices
+ {{0, 3, 3, 7, 9}}, // params_nested_splits
+ TensorShape({9}), // params_dense_values.shape
+ {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values
+ );
+
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Expected: [ [ [.4, .5, .6, .7], [.1, .2, .3] ],
+ // [ [], [.8, .9] ] ]
+ test::ExpectTensorEqual<int64>(*GetOutput(0),
+ test::AsTensor<int64>({0, 2, 4}));
+ test::ExpectTensorEqual<int64>(*GetOutput(1),
+ test::AsTensor<int64>({0, 4, 4, 7, 9}));
+ test::ExpectTensorNear<float>(
+ *GetOutput(2),
+ test::AsTensor<float>({.4, .5, .6, .7, .1, .2, .3, .8, .9}), 0.1);
+}
+
+TEST_F(RaggedGatherOpTest, RaggedGather_ScalarIndices) {
+ // indices = 2
+ // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]]
+ BuildRaggedGatherGraph<float, int32>(
+ TensorShape({}), // indices.shape
+ {2}, // indices
+ {{0, 3, 3, 7, 9}}, // params_nested_splits
+ TensorShape({9}), // params_dense_values.shape
+ {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values
+ );
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Expected: [.4, .5, .6, .7]
+ test::ExpectTensorNear<float>(*GetOutput(0),
+ test::AsTensor<float>({.4, .5, .6, .7}), 0.1);
+}
+
+TEST_F(RaggedGatherOpTest, RaggedGather_OutOfBounds) {
+ // indices = [2, 10]
+ // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]]
+ BuildRaggedGatherGraph<float, int32>(
+ TensorShape({2}), // indices.shape
+ {2, 10}, // indices
+ {{0, 3, 3, 7, 9}}, // params_nested_splits
+ TensorShape({9}), // params_dense_values.shape
+ {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values
+ );
+ EXPECT_EQ("indices[1] = 10 is not in [0, 4)", RunOpKernel().error_message());
+}
+
+TEST_F(RaggedGatherOpTest, InvalidSplitsNotSorted) {
+ BuildRaggedGatherGraph<float, int32>(
+ TensorShape({2}), // indices.shape
+ {0, 2}, // indices
+ {{0, 3, 5, 2, 9}}, // params_nested_splits
+ TensorShape({9}), // params_dense_values.shape
+ {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values
+ );
+ EXPECT_EQ("Ragged splits must be sorted", RunOpKernel().error_message());
+}
+
+TEST_F(RaggedGatherOpTest, InvalidSplitsNegative) {
+ BuildRaggedGatherGraph<float, int32>(
+ TensorShape({2}), // indices.shape
+ {0, 2}, // indices
+ {{-1, 3, 2, 7, 9}}, // params_nested_splits
+ TensorShape({9}), // params_dense_values.shape
+ {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values
+ );
+ EXPECT_EQ("Ragged splits must be non-negative",
+ RunOpKernel().error_message());
+}
+
+TEST_F(RaggedGatherOpTest, InvalidSplitsEmpty) {
+ BuildRaggedGatherGraph<float, int32>(
+ TensorShape({0}), // indices.shape
+ {}, // indices
+ {{}}, // params_nested_splits
+ TensorShape({0}), // params_dense_values.shape
+ {} // params_dense_values
+ );
+ EXPECT_EQ("Ragged splits may not be empty", RunOpKernel().error_message());
+}
+
+TEST_F(RaggedGatherOpTest, InvalidSplitsTooBig) {
+ BuildRaggedGatherGraph<float, int32>(
+ TensorShape({2}), // indices.shape
+ {0, 2}, // indices
+ {{0, 20, 40, 80, 100}}, // params_nested_splits
+ TensorShape({9}), // params_dense_values.shape
+ {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values
+ );
+ EXPECT_EQ("Ragged splits must not point past values",
+ RunOpKernel().error_message());
+}
+
+TEST_F(RaggedGatherOpTest, BadValuesShape) {
+ BuildRaggedGatherGraph<float, int32>(
+ TensorShape({0}), // indices.shape
+ {}, // indices
+ {{0}}, // params_nested_splits
+ TensorShape({}), // params_dense_values.shape
+ {.1} // params_dense_values
+ );
+ EXPECT_EQ("params.rank must be nonzero", RunOpKernel().error_message());
+}
+
+TEST_F(RaggedGatherOpTest, ShapeFn) {
+ // RaggedGather(param_splits+, param_values, indices) -> [splits+, values]
+ ShapeInferenceTestOp op("RaggedGather");
+
+ (*op.node_def.mutable_attr())["PARAMS_RAGGED_RANK"].set_i(1);
+ (*op.node_def.mutable_attr())["OUTPUT_RAGGED_RANK"].set_i(1);
+ INFER_OK(op, "?;?;?", "[?];?");
+ INFER_OK(op, "[?];[?];[?]", "[?];[?]");
+ INFER_OK(op, "[?];[?,?,?];[?]", "[?];[?,d1_1,d1_2]");
+ INFER_OK(op, "[5];[10];[15]", "[?];[?]");
+ INFER_OK(op, "[5];[10,2];[15]", "[?];[?,d1_1]");
+ INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[5];[];[]");
+ INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[1,2];[];[5]");
+
+ (*op.node_def.mutable_attr())["PARAMS_RAGGED_RANK"].set_i(2);
+ (*op.node_def.mutable_attr())["OUTPUT_RAGGED_RANK"].set_i(2);
+ INFER_OK(op, "?;?;?;?", "[?];[?];?");
+ INFER_OK(op, "[?];[?];[?];[?]", "[?];[?];[?]");
+ INFER_OK(op, "[?];[?];[?,?,?];[?]", "[?];[?];[?,d2_1,d2_2]");
+ INFER_OK(op, "[5];[10];[15];[20]", "[?];[?];[?]");
+
+ (*op.node_def.mutable_attr())["PARAMS_RAGGED_RANK"].set_i(1);
+ (*op.node_def.mutable_attr())["OUTPUT_RAGGED_RANK"].set_i(2);
+ INFER_OK(op, "?;?;?", "[?];[?];?");
+ INFER_OK(op, "[?];[?];[?,?]", "[?];[?];[?]");
+ INFER_OK(op, "[?];[?,?,?];[?,?]", "[?];[?];[?,d1_1,d1_2]");
+ INFER_OK(op, "[15];[20];[5,10]", "[?];[?];[?]");
+ INFER_OK(op, "[15];[20,2];[5,10]", "[?];[?];[?,d1_1]");
+
+ (*op.node_def.mutable_attr())["PARAMS_RAGGED_RANK"].set_i(1);
+ (*op.node_def.mutable_attr())["OUTPUT_RAGGED_RANK"].set_i(0);
+ INFER_OK(op, "[?];[?];[]", "[?]");
+}
+
+} // namespace
+} // namespace tensorflow
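
[Editor's note: to check the expected splits in the tests above by hand, a small hypothetical decoder (analogous to `ragged.from_nested_row_splits`) can rebuild nested Python lists from the asserted outputs. This is only a verification aid, not part of the test code.]

```python
# Decode flat values plus nested row_splits back into nested Python lists,
# applying each splits vector from innermost to outermost.
def from_nested_row_splits(dense_values, nested_row_splits):
    rows = list(dense_values)
    for splits in reversed(nested_row_splits):
        rows = [rows[splits[i]:splits[i + 1]] for i in range(len(splits) - 1)]
    return rows

# Expected outputs of the RaggedGather_3DParams test:
print(from_nested_row_splits(
    [.1, .2, .3, .4, .5, .6, .7, .8],
    [[0, 0, 2, 3, 3, 5], [0, 2, 3, 3, 5, 8]]))
# -> [[], [[.1, .2], [.3]], [[]], [], [[.4, .5], [.6, .7, .8]]]
```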
diff --git a/tensorflow/core/ops/ragged_array_ops.cc b/tensorflow/core/ops/ragged_array_ops.cc
new file mode 100644
index 0000000000..4642579939
--- /dev/null
+++ b/tensorflow/core/ops/ragged_array_ops.cc
@@ -0,0 +1,85 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
+Status RaggedGatherShapeFn(InferenceContext* c);
+
+//==============================================================================
+// Registered Ops
+//==============================================================================
+
+REGISTER_OP("RaggedGather")
+ .Input("params_nested_splits: PARAMS_RAGGED_RANK * int64")
+ .Input("params_dense_values: Tvalues")
+ .Input("indices: Tindices")
+ .Output("output_nested_splits: OUTPUT_RAGGED_RANK * int64")
+ .Output("output_dense_values: Tvalues")
+ .Attr("Tvalues: type")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("PARAMS_RAGGED_RANK: int >= 1")
+ .Attr("OUTPUT_RAGGED_RANK: int >= 0")
+ .SetShapeFn(RaggedGatherShapeFn);
+
+//==============================================================================
+// Shape Functions
+//==============================================================================
+
+Status RaggedGatherShapeFn(InferenceContext* c) {
+ int num_splits;
+ int64 PARAMS_RAGGED_RANK;
+ TF_RETURN_IF_ERROR(
+ c->GetAttr<int64>("PARAMS_RAGGED_RANK", &PARAMS_RAGGED_RANK));
+ TF_RETURN_IF_ERROR(c->GetAttr<int>("OUTPUT_RAGGED_RANK", &num_splits));
+
+ // Check rank of `indices`.
+ ShapeHandle indices = c->input(PARAMS_RAGGED_RANK + 1);
+ TF_RETURN_IF_ERROR(
+ c->WithRank(indices, num_splits - PARAMS_RAGGED_RANK + 1, &indices));
+
+ // Check that all params_nested_splits have rank 1.
+ for (int64 i = 0; i < PARAMS_RAGGED_RANK; ++i) {
+ ShapeHandle splits = c->input(i);
+ TF_RETURN_IF_ERROR(c->WithRank(splits, 1, &splits));
+ }
+
+ // Check that `params_dense_values` has rank>=1.
+ ShapeHandle params_dense_values = c->input(PARAMS_RAGGED_RANK);
+ TF_RETURN_IF_ERROR(
+ c->WithRankAtLeast(params_dense_values, 1, &params_dense_values));
+
+ // Set the rank for the `splits` outputs.
+ for (int i = 0; i < num_splits; ++i) {
+ c->set_output(i, c->UnknownShapeOfRank(1));
+ }
+
+ // Calculate the `values` shape.
+ ShapeHandle value = c->UnknownShape();
+ ShapeHandle values = c->UnknownShape();
+ TF_RETURN_IF_ERROR(c->Subshape(params_dense_values, 1, &value));
+ TF_RETURN_IF_ERROR(c->Concatenate(c->UnknownShapeOfRank(1), value, &values));
+ c->set_output(num_splits, values);
+
+ return Status::OK();
+}
+
+} // namespace tensorflow
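
[Editor's note: a small sketch of the shape relationship that RaggedGatherShapeFn encodes, using plain Python tuples with None for unknown dimensions. It approximates the InferenceContext logic above; it is not a drop-in replacement and omits the rank checks on `indices` and the splits inputs.]

```python
# Each output splits tensor is an unknown-length vector; the output values
# tensor keeps params_dense_values' inner dimensions with an unknown outer one.
def ragged_gather_shapes(params_dense_values_shape, output_ragged_rank):
    splits_shapes = [(None,)] * output_ragged_rank
    values_shape = (None,) + tuple(params_dense_values_shape[1:])
    return splits_shapes, values_shape

# e.g. params_dense_values of shape [10, 2] with OUTPUT_RAGGED_RANK = 1:
# -> ([(None,)], (None, 2)), matching the "[5];[10,2];[15]" -> "[?];[?,d1_1]"
#    case exercised in the ShapeFn test.
print(ragged_gather_shapes((10, 2), 1))
```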