From 0c6baae5af46bb22ea52db724e2194845d3bbf8c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 9 Oct 2018 12:24:05 -0700 Subject: Add RaggedTensors to tf.core. Moving the RaggedGather op kernel. PiperOrigin-RevId: 216400726 --- tensorflow/core/BUILD | 15 ++ .../api_def/base_api/api_def_RaggedGather.pbtxt | 81 ++++++ tensorflow/core/kernels/BUILD | 31 +++ tensorflow/core/kernels/ragged_gather_op.cc | 292 +++++++++++++++++++++ tensorflow/core/kernels/ragged_gather_op_test.cc | 281 ++++++++++++++++++++ tensorflow/core/ops/ragged_array_ops.cc | 85 ++++++ 6 files changed, 785 insertions(+) create mode 100644 tensorflow/core/api_def/base_api/api_def_RaggedGather.pbtxt create mode 100644 tensorflow/core/kernels/ragged_gather_op.cc create mode 100644 tensorflow/core/kernels/ragged_gather_op_test.cc create mode 100644 tensorflow/core/ops/ragged_array_ops.cc diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index acea8e2217..9e7806342a 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1154,6 +1154,19 @@ tf_gen_op_libs( ], ) +cc_library( + name = "ragged_ops", + deps = [ + ":ragged_array_ops_op_lib", + ], +) + +tf_gen_op_libs( + op_lib_names = [ + "ragged_array_ops", + ], +) + cc_library( name = "ops", visibility = ["//visibility:public"], @@ -1187,6 +1200,7 @@ cc_library( ":nn_ops_op_lib", ":no_op_op_lib", ":parsing_ops_op_lib", + ":ragged_ops", ":random_ops_op_lib", ":remote_fused_graph_ops_op_lib", ":resource_variable_ops_op_lib", @@ -1340,6 +1354,7 @@ cc_library( "//tensorflow/core/kernels:parameterized_truncated_normal_op", "//tensorflow/core/kernels:parsing", "//tensorflow/core/kernels:partitioned_function_ops", + "//tensorflow/core/kernels:ragged_ops", "//tensorflow/core/kernels:random_ops", "//tensorflow/core/kernels:random_poisson_op", "//tensorflow/core/kernels:remote_fused_graph_ops", diff --git a/tensorflow/core/api_def/base_api/api_def_RaggedGather.pbtxt b/tensorflow/core/api_def/base_api/api_def_RaggedGather.pbtxt new file mode 100644 index 0000000000..240c987dda --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_RaggedGather.pbtxt @@ -0,0 +1,81 @@ +op { + graph_op_name: "RaggedGather" + visibility: HIDDEN + in_arg { + name: "params_nested_splits" + description: < +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/util/util.h" + +namespace tensorflow { + +namespace { + +// For each slice in `(start, limit)` in `value_slices`, append +// `params_dense_values_in[start:limit] to `values_out`. `value_size` indicates +// the number of scalars contained in each value params_dense_values_in[i]. +template +void WriteValueSlices(const Tensor& params_dense_values_in, + const std::vector>& value_slices, + int64 value_size, Tensor* values_out) { + const auto& params_dense_values = + params_dense_values_in.flat_outer_dims(); + auto values = values_out->flat_outer_dims(); + int out_pos = 0; + for (const auto& slice : value_slices) { + for (int i = slice.first; i < slice.second; ++i) { + for (int j = 0; j < value_size; ++j) { + values(out_pos, j) = params_dense_values(i, j); + } + ++out_pos; + } + } +} + +} // namespace + +template +class RaggedGatherOpBase : public OpKernel { + public: + using OpKernel::OpKernel; + + void Compute(OpKernelContext* context) override { + // Get the input Tensors. + OpInputList params_nested_splits_in; + OP_REQUIRES_OK(context, context->input_list("params_nested_splits", + ¶ms_nested_splits_in)); + const Tensor& params_dense_values_in = + context->input(params_nested_splits_in.size()); + const Tensor& indices_in = + context->input(params_nested_splits_in.size() + 1); + + DCHECK_GT(params_nested_splits_in.size(), 0); // Enforced by REGISTER_OP. + int64 num_params = params_nested_splits_in[0].dim_size(0) - 1; + OP_REQUIRES_OK(context, ValidateIndices(indices_in, num_params)); + + OP_REQUIRES(context, params_dense_values_in.dims() > 0, + errors::InvalidArgument("params.rank must be nonzero")); + int64 num_params_dense_values = params_dense_values_in.dim_size(0); + + // Calculate the `splits`, and store the value slices that we need to + // copy in `value_slices`. + std::vector> value_slices; + int64 num_values = 0; + std::vector> out_splits; + OP_REQUIRES_OK(context, MakeSplits(indices_in, params_nested_splits_in, + num_params_dense_values, &out_splits, + &value_slices, &num_values)); + + // Write the output tensors. + OP_REQUIRES_OK(context, WriteSplits(out_splits, context)); + OP_REQUIRES_OK(context, + WriteValues(params_dense_values_in, value_slices, + out_splits.size(), num_values, context)); + } + + private: + // Check if any indices are out-of-bounds. + ::tensorflow::Status ValidateIndices(const Tensor& indices_in, + int64 num_params) { + const auto& indices = indices_in.flat(); + for (int64 i = 0; i < indices.size(); ++i) { + int64 index = indices(i); + if (index < 0 || index >= num_params) { + return errors::InvalidArgument( + "indices", SliceDebugString(indices_in.shape(), i), " = ", index, + " is not in [0, ", num_params, ")"); + } + } + return ::tensorflow::Status::OK(); + } + + // Construct the `splits` output tensors, encoded using a nested vector. + // Also find the slices of values that need to be copied, and store them + // in `value_slices`. The total number of values that will be copied (which + // we need for allocating the output values tensor) is stored in `num_values`. + ::tensorflow::Status MakeSplits( + const Tensor& indices_in, const OpInputList& params_nested_splits_in, + int64 num_params_dense_values, + std::vector>* out_splits, + std::vector>* value_slices, int64* num_values) { + *num_values = 0; + value_slices->clear(); + + int num_splits = indices_in.dims() - 1 + params_nested_splits_in.size(); + out_splits->assign(num_splits, {0}); + + // Get Eigen tensors. + const auto& indices = indices_in.flat(); + std::vector::ConstFlat> params_nested_splits; + params_nested_splits.reserve(params_nested_splits_in.size()); + for (const auto& splits_in : params_nested_splits_in) { + params_nested_splits.push_back(splits_in.flat()); + } + + TF_RETURN_IF_ERROR( + ValidateSplits(params_nested_splits, num_params_dense_values)); + + // Add `splits` that come from all but the last dimension of the dense + // Tensor `indices`. In particular, for each dimension D, we add a + // splits tensor whose values are: + // range(splits.shape[D]*splits.shape[D+1] + 1, step=splits.shape[D+1]) + // E.g., if indices.shape=[5, 3] then we will add a splits tensor + // [0, 3, 6, 9, 12, 15], since the outermost dimension has 5 elements, + // each of which contains 3 values. + for (int dim = 0; dim < indices_in.dims() - 1; ++dim) { + int stride = indices_in.dim_size(dim + 1); + int index = stride; + for (int i = 0; i < indices_in.dim_size(dim); ++i) { + out_splits->at(dim).push_back(index); + index += stride; + } + } + + // Add `splits` that come from `params_nested_splits`. Starting with the + // outermost ragged dimension (i.e., the first `splits` tensor), we work + // our way in, finding the range of values that should be copied. As we + // go, we update the output `splits` for each dimension with the appropriate + // values. In particular, the *lengths* of the slices from `param_splits` + // should be copied to generate corresponding slice lengths in the output + // splits. E.g., if we are copying a ragged row with length 4, then we + // should add a new split point to out_splits that is 4 greater than the + // previous split point in out_splits. + for (int i = 0; i < indices.size(); ++i) { + int start = indices(i); + int limit = indices(i) + 1; + + // Copy splits. + for (int dim = 0; dim < params_nested_splits.size(); ++dim) { + const auto& splits = params_nested_splits[dim]; + int out_dim = dim + indices_in.dims() - 1; + if (out_dim >= 0) { + int64 delta = out_splits->at(out_dim).back() - splits(start); + for (int j = start; j < limit; ++j) { + out_splits->at(out_dim).push_back(splits(j + 1) + delta); + } + } + start = splits(start); + limit = splits(limit); + } + if (limit != start) { + value_slices->emplace_back(start, limit); + *num_values += limit - start; + } + } + return ::tensorflow::Status::OK(); + } + + ::tensorflow::Status ValidateSplits( + const std::vector::ConstFlat>& params_nested_splits, + int64 num_params_dense_values) { + // Validate + for (int dim = 0; dim < params_nested_splits.size(); ++dim) { + const auto& splits = params_nested_splits[dim]; + int64 last_split = (dim == params_nested_splits.size() - 1) + ? num_params_dense_values + : params_nested_splits[dim + 1].size(); + if (splits.size() == 0) { + return errors::InvalidArgument("Ragged splits may not be empty"); + } + if (splits(0) < 0) { + return errors::InvalidArgument("Ragged splits must be non-negative"); + } + if (splits(splits.size() - 1) > last_split) { + return errors::InvalidArgument( + "Ragged splits must not point past values"); + } + for (int i = 1; i < splits.size(); ++i) { + if (splits(i - 1) > splits(i)) { + return errors::InvalidArgument("Ragged splits must be sorted"); + } + } + } + return ::tensorflow::Status::OK(); + } + + ::tensorflow::Status WriteSplits( + const std::vector>& out_splits, + OpKernelContext* context) { + OpOutputList splits_out; + TF_RETURN_IF_ERROR( + context->output_list("output_nested_splits", &splits_out)); + for (int i = 0; i < out_splits.size(); ++i) { + Tensor* splits; + int64 num_splits = out_splits[i].size(); + TF_RETURN_IF_ERROR( + splits_out.allocate(i, TensorShape({num_splits}), &splits)); + auto splits_flat = splits->flat(); + std::copy_n(out_splits[i].data(), out_splits[i].size(), + splits_flat.data()); + } + return ::tensorflow::Status::OK(); + } + + ::tensorflow::Status WriteValues( + const Tensor& params_dense_values_in, + const std::vector>& value_slices, + int values_index, int64 num_values, OpKernelContext* context) const { + Tensor* values_out = nullptr; + TensorShape values_shape = params_dense_values_in.shape(); + values_shape.set_dim(0, num_values); + TF_RETURN_IF_ERROR( + context->allocate_output(values_index, values_shape, &values_out)); + int64 value_size = params_dense_values_in.NumElements() / + params_dense_values_in.dim_size(0); + CallWriteValueSlices(params_dense_values_in, value_slices, value_size, + values_out); + return ::tensorflow::Status::OK(); + } + + protected: + // Call WriteValueSlices() using the appropriate VALUE_TYPE template + // parameter. This pattern is used to reduce binary size. In particular, + // this allows us to have two instantiations of this class (one for each + // index type), rather than 14 (one for each index type and value type), + // which cuts the binary size of this op from ~300k to <90k. + virtual void CallWriteValueSlices( + const Tensor& params_dense_values_in, + const std::vector>& value_slices, + int64 value_size, Tensor* values_out) const = 0; +}; + +template +class RaggedGatherOp : public RaggedGatherOpBase { + public: + using RaggedGatherOpBase::RaggedGatherOpBase; + + private: + void CallWriteValueSlices( + const Tensor& params_dense_values_in, + const std::vector>& value_slices, + int64 value_size, Tensor* values_out) const override { + WriteValueSlices(params_dense_values_in, value_slices, + value_size, values_out); + } +}; + +#define REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(index_type, value_type) \ + REGISTER_KERNEL_BUILDER(Name("RaggedGather") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("Tindices") \ + .TypeConstraint("Tvalues"), \ + RaggedGatherOp); +#define REGISTER_CPU_KERNEL(value_type) \ + REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int32, value_type) \ + REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int64, value_type) +TF_CALL_POD_TYPES(REGISTER_CPU_KERNEL); +TF_CALL_string(REGISTER_CPU_KERNEL); +TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL); +TF_CALL_quint16(REGISTER_CPU_KERNEL); +TF_CALL_qint16(REGISTER_CPU_KERNEL); +TF_CALL_uint32(REGISTER_CPU_KERNEL); +TF_CALL_uint64(REGISTER_CPU_KERNEL); +#undef REGISTER_CPU_KERNEL +#undef REGISTER_CPU_KERNEL_WITH_INDEX_TYPE + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/ragged_gather_op_test.cc b/tensorflow/core/kernels/ragged_gather_op_test.cc new file mode 100644 index 0000000000..47be788151 --- /dev/null +++ b/tensorflow/core/kernels/ragged_gather_op_test.cc @@ -0,0 +1,281 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/shape_inference_testutil.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +class RaggedGatherOpTest : public ::tensorflow::OpsTestBase { + protected: + // Builds the tensorflow test graph for RaggedGather. + template + void BuildRaggedGatherGraph( + const TensorShape& indices_shape, const std::vector& indices, + const std::vector>& params_nested_splits, + const TensorShape& params_dense_values_shape, + const gtl::ArraySlice params_dense_values) { + const auto& value_dtype = DataTypeToEnum::v(); + const auto& index_dtype = DataTypeToEnum::v(); + int64 PARAMS_RAGGED_RANK = params_nested_splits.size(); + int64 num_splits = PARAMS_RAGGED_RANK + indices_shape.dims() - 1; + TF_ASSERT_OK( + NodeDefBuilder("tested_op", "RaggedGather") + .Input(FakeInput(PARAMS_RAGGED_RANK)) // params_nested_splits + .Input(FakeInput(value_dtype)) // params_dense_values + .Input(FakeInput(index_dtype)) // indices + .Attr("PARAMS_RAGGED_RANK", PARAMS_RAGGED_RANK) + .Attr("OUTPUT_RAGGED_RANK", num_splits) + .Attr("Tvalues", value_dtype) + .Attr("Tindices", index_dtype) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + for (const auto& splits : params_nested_splits) { + int64 splits_size = splits.size(); + AddInputFromArray(TensorShape({splits_size}), splits); + } + AddInputFromArray(params_dense_values_shape, + params_dense_values); + AddInputFromArray(indices_shape, indices); + } +}; + +TEST_F(RaggedGatherOpTest, RaggedGather) { + // indices = [2, 1, 0, 3] + // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]] + // params.shape = [4, None] + BuildRaggedGatherGraph( + TensorShape({4}), // indices.shape + {2, 1, 0, 3}, // indices + {{0, 3, 3, 7, 9}}, // params_nested_splits + TensorShape({9}), // params_dense_values.shape + {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values + ); + + TF_ASSERT_OK(RunOpKernel()); + + // Expected: [[.4, .5, .6, .7], [.1, .2, .3], [], [.8, .9]] + test::ExpectTensorEqual(*GetOutput(0), + test::AsTensor({0, 4, 4, 7, 9})); + test::ExpectTensorNear( + *GetOutput(1), + test::AsTensor({.4, .5, .6, .7, .1, .2, .3, .8, .9}), 0.1); +} + +TEST_F(RaggedGatherOpTest, RaggedGather_3DParams) { + // indices = [2, 1, 0, 2, 3] + // params = [[[]], [[.1, 2], [.3]], [], [[.4, .5], [.6, .7, .8]], [[.9]]] + // params.shape = [5, None, None] + BuildRaggedGatherGraph( + TensorShape({5}), // indices.shape + {2, 1, 0, 2, 3}, // indices + {{0, 1, 3, 3, 5, 6}, {0, 0, 2, 3, 5, 8, 9}}, // params_nested_splits + TensorShape({9}), // params_dense_values.shape + {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values + ); + + TF_ASSERT_OK(RunOpKernel()); + + // Expected: [[], [[.1, 2], [.3]], [[]], [], [[.4, .5], [.6, .7, .8]]] + test::ExpectTensorEqual(*GetOutput(0), + test::AsTensor({0, 0, 2, 3, 3, 5})); + test::ExpectTensorEqual(*GetOutput(1), + test::AsTensor({0, 2, 3, 3, 5, 8})); + test::ExpectTensorNear( + *GetOutput(2), test::AsTensor({.1, .2, .3, .4, .5, .6, .7, .8}), + 0.1); +} + +TEST_F(RaggedGatherOpTest, RaggedGather_4DParams) { + // indices = [2, 1, 0, 2] + // params = [[[]], [[[1, 2], [3, 4], [5, 6]], [[7, 8]]], []] + // params.shape = [4, None, None, 2] + BuildRaggedGatherGraph( + TensorShape({4}), // indices.shape + {2, 1, 0, 2}, // indices + {{0, 1, 3, 3}, {0, 0, 3, 4}}, // params_nested_splits + TensorShape({4, 2}), // params_dense_values.shape + {1, 2, 3, 4, 5, 6, 7, 8} // params_dense_values + ); + + TF_ASSERT_OK(RunOpKernel()); + + // Expected: [[], + // [[[1, 2], [3, 4], [5, 6]], [[7, 8]]], + // [[]], + // []] + test::ExpectTensorEqual(*GetOutput(0), + test::AsTensor({0, 0, 2, 3, 3})); + test::ExpectTensorEqual(*GetOutput(1), + test::AsTensor({0, 3, 4, 4})); + test::ExpectTensorEqual( + *GetOutput(2), + test::AsTensor({1, 2, 3, 4, 5, 6, 7, 8}, TensorShape({4, 2}))); +} + +TEST_F(RaggedGatherOpTest, RaggedGather_2DIndices) { + // indices = [[2, 1], [0, 3]] + // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]] + BuildRaggedGatherGraph( + TensorShape({2, 2}), // indices.shape + {2, 1, 0, 3}, // indices + {{0, 3, 3, 7, 9}}, // params_nested_splits + TensorShape({9}), // params_dense_values.shape + {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values + ); + + TF_ASSERT_OK(RunOpKernel()); + + // Expected: [ [ [.4, .5, .6, .7], [.1, .2, .3] ], + // [ [], [.8, .9] ] ] + test::ExpectTensorEqual(*GetOutput(0), + test::AsTensor({0, 2, 4})); + test::ExpectTensorEqual(*GetOutput(1), + test::AsTensor({0, 4, 4, 7, 9})); + test::ExpectTensorNear( + *GetOutput(2), + test::AsTensor({.4, .5, .6, .7, .1, .2, .3, .8, .9}), 0.1); +} + +TEST_F(RaggedGatherOpTest, RaggedGather_ScalarIndices) { + // indices = 2 + // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]] + BuildRaggedGatherGraph( + TensorShape({}), // indices.shape + {2}, // indices + {{0, 3, 3, 7, 9}}, // params_nested_splits + TensorShape({9}), // params_dense_values.shape + {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values + ); + TF_ASSERT_OK(RunOpKernel()); + + // Expected: [.4, .5, .6, .7] + test::ExpectTensorNear(*GetOutput(0), + test::AsTensor({.4, .5, .6, .7}), 0.1); +} + +TEST_F(RaggedGatherOpTest, RaggedGather_OutOfBounds) { + // indices = [2, 10] + // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]] + BuildRaggedGatherGraph( + TensorShape({2}), // indices.shape + {2, 10}, // indices + {{0, 3, 3, 7, 9}}, // params_nested_splits + TensorShape({9}), // params_dense_values.shape + {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values + ); + EXPECT_EQ("indices[1] = 10 is not in [0, 4)", RunOpKernel().error_message()); +} + +TEST_F(RaggedGatherOpTest, InvalidSplitsNotSorted) { + BuildRaggedGatherGraph( + TensorShape({2}), // indices.shape + {0, 2}, // indices + {{0, 3, 5, 2, 9}}, // params_nested_splits + TensorShape({9}), // params_dense_values.shape + {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values + ); + EXPECT_EQ("Ragged splits must be sorted", RunOpKernel().error_message()); +} + +TEST_F(RaggedGatherOpTest, InvalidSplitsNegative) { + BuildRaggedGatherGraph( + TensorShape({2}), // indices.shape + {0, 2}, // indices + {{-1, 3, 2, 7, 9}}, // params_nested_splits + TensorShape({9}), // params_dense_values.shape + {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values + ); + EXPECT_EQ("Ragged splits must be non-negative", + RunOpKernel().error_message()); +} + +TEST_F(RaggedGatherOpTest, InvalidSplitsEmpty) { + BuildRaggedGatherGraph( + TensorShape({0}), // indices.shape + {}, // indices + {{}}, // params_nested_splits + TensorShape({0}), // params_dense_values.shape + {} // params_dense_values + ); + EXPECT_EQ("Ragged splits may not be empty", RunOpKernel().error_message()); +} + +TEST_F(RaggedGatherOpTest, InvalidSplitsTooBig) { + BuildRaggedGatherGraph( + TensorShape({2}), // indices.shape + {0, 2}, // indices + {{0, 20, 40, 80, 100}}, // params_nested_splits + TensorShape({9}), // params_dense_values.shape + {.1, .2, .3, .4, .5, .6, .7, .8, .9} // params_dense_values + ); + EXPECT_EQ("Ragged splits must not point past values", + RunOpKernel().error_message()); +} + +TEST_F(RaggedGatherOpTest, BadValuesShape) { + BuildRaggedGatherGraph( + TensorShape({0}), // indices.shape + {}, // indices + {{0}}, // params_nested_splits + TensorShape({}), // params_dense_values.shape + {.1} // params_dense_values + ); + EXPECT_EQ("params.rank must be nonzero", RunOpKernel().error_message()); +} + +TEST_F(RaggedGatherOpTest, ShapeFn) { + // RaggedGather(param_splits+, param_values, indices) -> [splits+, values] + ShapeInferenceTestOp op("RaggedGather"); + + (*op.node_def.mutable_attr())["PARAMS_RAGGED_RANK"].set_i(1); + (*op.node_def.mutable_attr())["OUTPUT_RAGGED_RANK"].set_i(1); + INFER_OK(op, "?;?;?", "[?];?"); + INFER_OK(op, "[?];[?];[?]", "[?];[?]"); + INFER_OK(op, "[?];[?,?,?];[?]", "[?];[?,d1_1,d1_2]"); + INFER_OK(op, "[5];[10];[15]", "[?];[?]"); + INFER_OK(op, "[5];[10,2];[15]", "[?];[?,d1_1]"); + INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[5];[];[]"); + INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[1,2];[];[5]"); + + (*op.node_def.mutable_attr())["PARAMS_RAGGED_RANK"].set_i(2); + (*op.node_def.mutable_attr())["OUTPUT_RAGGED_RANK"].set_i(2); + INFER_OK(op, "?;?;?;?", "[?];[?];?"); + INFER_OK(op, "[?];[?];[?];[?]", "[?];[?];[?]"); + INFER_OK(op, "[?];[?];[?,?,?];[?]", "[?];[?];[?,d2_1,d2_2]"); + INFER_OK(op, "[5];[10];[15];[20]", "[?];[?];[?]"); + + (*op.node_def.mutable_attr())["PARAMS_RAGGED_RANK"].set_i(1); + (*op.node_def.mutable_attr())["OUTPUT_RAGGED_RANK"].set_i(2); + INFER_OK(op, "?;?;?", "[?];[?];?"); + INFER_OK(op, "[?];[?];[?,?]", "[?];[?];[?]"); + INFER_OK(op, "[?];[?,?,?];[?,?]", "[?];[?];[?,d1_1,d1_2]"); + INFER_OK(op, "[15];[20];[5,10]", "[?];[?];[?]"); + INFER_OK(op, "[15];[20,2];[5,10]", "[?];[?];[?,d1_1]"); + + (*op.node_def.mutable_attr())["PARAMS_RAGGED_RANK"].set_i(1); + (*op.node_def.mutable_attr())["OUTPUT_RAGGED_RANK"].set_i(0); + INFER_OK(op, "[?];[?];[]", "[?]"); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/ops/ragged_array_ops.cc b/tensorflow/core/ops/ragged_array_ops.cc new file mode 100644 index 0000000000..4642579939 --- /dev/null +++ b/tensorflow/core/ops/ragged_array_ops.cc @@ -0,0 +1,85 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +Status RaggedGatherShapeFn(InferenceContext* c); + +//============================================================================== +// Registered Ops +//============================================================================== + +REGISTER_OP("RaggedGather") + .Input("params_nested_splits: PARAMS_RAGGED_RANK * int64") + .Input("params_dense_values: Tvalues") + .Input("indices: Tindices") + .Output("output_nested_splits: OUTPUT_RAGGED_RANK * int64") + .Output("output_dense_values: Tvalues") + .Attr("Tvalues: type") + .Attr("Tindices: {int32, int64}") + .Attr("PARAMS_RAGGED_RANK: int >= 1") + .Attr("OUTPUT_RAGGED_RANK: int >= 0") + .SetShapeFn(RaggedGatherShapeFn); + +//============================================================================== +// Shape Functions +//============================================================================== + +Status RaggedGatherShapeFn(InferenceContext* c) { + int num_splits; + int64 PARAMS_RAGGED_RANK; + TF_RETURN_IF_ERROR( + c->GetAttr("PARAMS_RAGGED_RANK", &PARAMS_RAGGED_RANK)); + TF_RETURN_IF_ERROR(c->GetAttr("OUTPUT_RAGGED_RANK", &num_splits)); + + // Check rank of `indices`. + ShapeHandle indices = c->input(PARAMS_RAGGED_RANK + 1); + TF_RETURN_IF_ERROR( + c->WithRank(indices, num_splits - PARAMS_RAGGED_RANK + 1, &indices)); + + // Check that all params_nested_splits have rank 1. + for (int64 i = 0; i < PARAMS_RAGGED_RANK; ++i) { + ShapeHandle splits = c->input(i); + TF_RETURN_IF_ERROR(c->WithRank(splits, 1, &splits)); + } + + // Check that `params_dense_values` has rank>=1. + ShapeHandle params_dense_values = c->input(PARAMS_RAGGED_RANK); + TF_RETURN_IF_ERROR( + c->WithRankAtLeast(params_dense_values, 1, ¶ms_dense_values)); + + // Set the rank for the `splits` outputs. + for (int i = 0; i < num_splits; ++i) { + c->set_output(i, c->UnknownShapeOfRank(1)); + } + + // Calculate the `values` shape. + ShapeHandle value = c->UnknownShape(); + ShapeHandle values = c->UnknownShape(); + TF_RETURN_IF_ERROR(c->Subshape(params_dense_values, 1, &value)); + TF_RETURN_IF_ERROR(c->Concatenate(c->UnknownShapeOfRank(1), value, &values)); + c->set_output(num_splits, values); + + return Status::OK(); +} + +} // namespace tensorflow -- cgit v1.2.3