/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

// Test fixture for the RaggedGather kernel.
//
// NOTE(review): the explicit template argument lists used throughout this
// file (e.g. <float, int32>, <int64>) are chosen to match the literal values
// and the Expect* assertions in each test; the index type (int32) in
// particular is an assumption -- confirm it against the op registration.
class RaggedGatherOpTest : public ::tensorflow::OpsTestBase {
 protected:
  // Builds the tensorflow test graph for RaggedGather and feeds its inputs.
  //
  // VALUE_TYPE is the element type of `params_dense_values`; INDEX_TYPE is the
  // element type of `indices`.
  //
  // Args:
  //   indices_shape: shape of the `indices` input tensor.
  //   indices: flattened contents of the `indices` input tensor.
  //   params_nested_splits: one splits vector per ragged dimension of params;
  //     each is fed as a 1-D tensor, in order.
  //   params_dense_values_shape: shape of the params values tensor.
  //   params_dense_values: flattened contents of the params values tensor.
  //
  // The PARAMS_RAGGED_RANK attr is set to the number of splits vectors, and
  // OUTPUT_RAGGED_RANK to that number plus `indices_shape.dims() - 1`.
  template <typename VALUE_TYPE, typename INDEX_TYPE>
  void BuildRaggedGatherGraph(
      const TensorShape& indices_shape, const std::vector<INDEX_TYPE>& indices,
      const std::vector<std::vector<int64>>& params_nested_splits,
      const TensorShape& params_dense_values_shape,
      const gtl::ArraySlice<VALUE_TYPE> params_dense_values) {
    const auto& value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
    const auto& index_dtype = DataTypeToEnum<INDEX_TYPE>::v();
    const int64 params_ragged_rank = params_nested_splits.size();
    const int64 output_ragged_rank =
        params_ragged_rank + indices_shape.dims() - 1;
    TF_ASSERT_OK(
        NodeDefBuilder("tested_op", "RaggedGather")
            .Input(FakeInput(params_ragged_rank))  // params_nested_splits
            .Input(FakeInput(value_dtype))         // params_dense_values
            .Input(FakeInput(index_dtype))         // indices
            .Attr("PARAMS_RAGGED_RANK", params_ragged_rank)
            .Attr("OUTPUT_RAGGED_RANK", output_ragged_rank)
            .Attr("Tvalues", value_dtype)
            .Attr("Tindices", index_dtype)
            .Finalize(node_def()));
    TF_ASSERT_OK(InitOp());

    // Inputs are added in the same order they were declared on the node:
    // every splits vector first, then the dense values, then the indices.
    for (const auto& splits : params_nested_splits) {
      // Named int64 so the braced TensorShape initializer does not narrow.
      const int64 splits_size = splits.size();
      AddInputFromArray<int64>(TensorShape({splits_size}), splits);
    }
    AddInputFromArray<VALUE_TYPE>(params_dense_values_shape,
                                  params_dense_values);
    AddInputFromArray<INDEX_TYPE>(indices_shape, indices);
  }
};

TEST_F(RaggedGatherOpTest, RaggedGather) {
  // indices = [2, 1, 0, 3]
  // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]]
  // params.shape = [4, None]
  BuildRaggedGatherGraph<float, int32>(
      TensorShape({4}),                     // indices.shape
      {2, 1, 0, 3},                         // indices
      {{0, 3, 3, 7, 9}},                    // params_nested_splits
      TensorShape({9}),                     // params_dense_values.shape
      {.1, .2, .3, .4, .5, .6, .7, .8, .9}  // params_dense_values
  );

  TF_ASSERT_OK(RunOpKernel());

  // Expected: [[.4, .5, .6, .7], [.1, .2, .3], [], [.8, .9]]
  test::ExpectTensorEqual<int64>(*GetOutput(0),
                                 test::AsTensor<int64>({0, 4, 4, 7, 9}));
  test::ExpectTensorNear<float>(
      *GetOutput(1),
      test::AsTensor<float>({.4, .5, .6, .7, .1, .2, .3, .8, .9}), 0.1);
}

TEST_F(RaggedGatherOpTest, RaggedGather_3DParams) {
  // indices = [2, 1, 0, 2, 3]
  // params = [[[]], [[.1, .2], [.3]], [], [[.4, .5], [.6, .7, .8]], [[.9]]]
  // params.shape = [5, None, None]
  BuildRaggedGatherGraph<float, int32>(
      TensorShape({5}),                                // indices.shape
      {2, 1, 0, 2, 3},                                 // indices
      {{0, 1, 3, 3, 5, 6}, {0, 0, 2, 3, 5, 8, 9}},     // params_nested_splits
      TensorShape({9}),                     // params_dense_values.shape
      {.1, .2, .3, .4, .5, .6, .7, .8, .9}  // params_dense_values
  );

  TF_ASSERT_OK(RunOpKernel());

  // Expected: [[], [[.1, .2], [.3]], [[]], [], [[.4, .5], [.6, .7, .8]]]
  test::ExpectTensorEqual<int64>(*GetOutput(0),
                                 test::AsTensor<int64>({0, 0, 2, 3, 3, 5}));
  test::ExpectTensorEqual<int64>(*GetOutput(1),
                                 test::AsTensor<int64>({0, 2, 3, 3, 5, 8}));
  test::ExpectTensorNear<float>(
      *GetOutput(2), test::AsTensor<float>({.1, .2, .3, .4, .5, .6, .7, .8}),
      0.1);
}

TEST_F(RaggedGatherOpTest, RaggedGather_4DParams) {
  // indices = [2, 1, 0, 2]
  // params = [[[]], [[[1, 2], [3, 4], [5, 6]], [[7, 8]]], []]
  // params.shape = [3, None, None, 2]
  BuildRaggedGatherGraph<int32, int32>(
      TensorShape({4}),              // indices.shape
      {2, 1, 0, 2},                  // indices
      {{0, 1, 3, 3}, {0, 0, 3, 4}},  // params_nested_splits
      TensorShape({4, 2}),           // params_dense_values.shape
      {1, 2, 3, 4, 5, 6, 7, 8}       // params_dense_values
  );

  TF_ASSERT_OK(RunOpKernel());

  // Expected: [[],
  //            [[[1, 2], [3, 4], [5, 6]], [[7, 8]]],
  //            [[]],
  //            []]
  test::ExpectTensorEqual<int64>(*GetOutput(0),
                                 test::AsTensor<int64>({0, 0, 2, 3, 3}));
  test::ExpectTensorEqual<int64>(*GetOutput(1),
                                 test::AsTensor<int64>({0, 3, 4, 4}));
  test::ExpectTensorEqual<int32>(
      *GetOutput(2),
      test::AsTensor<int32>({1, 2, 3, 4, 5, 6, 7, 8}, TensorShape({4, 2})));
}

TEST_F(RaggedGatherOpTest, RaggedGather_2DIndices) {
  // indices = [[2, 1], [0, 3]]
  // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]]
  BuildRaggedGatherGraph<float, int32>(
      TensorShape({2, 2}),                  // indices.shape
      {2, 1, 0, 3},                         // indices
      {{0, 3, 3, 7, 9}},                    // params_nested_splits
      TensorShape({9}),                     // params_dense_values.shape
      {.1, .2, .3, .4, .5, .6, .7, .8, .9}  // params_dense_values
  );

  TF_ASSERT_OK(RunOpKernel());

  // Expected: [ [ [.4, .5, .6, .7], [.1, .2, .3] ],
  //             [ [],               [.8, .9]     ] ]
  test::ExpectTensorEqual<int64>(*GetOutput(0),
                                 test::AsTensor<int64>({0, 2, 4}));
  test::ExpectTensorEqual<int64>(*GetOutput(1),
                                 test::AsTensor<int64>({0, 4, 4, 7, 9}));
  test::ExpectTensorNear<float>(
      *GetOutput(2),
      test::AsTensor<float>({.4, .5, .6, .7, .1, .2, .3, .8, .9}), 0.1);
}

TEST_F(RaggedGatherOpTest, RaggedGather_ScalarIndices) {
  // indices = 2
  // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]]
  BuildRaggedGatherGraph<float, int32>(
      TensorShape({}),                      // indices.shape
      {2},                                  // indices
      {{0, 3, 3, 7, 9}},                    // params_nested_splits
      TensorShape({9}),                     // params_dense_values.shape
      {.1, .2, .3, .4, .5, .6, .7, .8, .9}  // params_dense_values
  );

  TF_ASSERT_OK(RunOpKernel());

  // Expected: [.4, .5, .6, .7]
  test::ExpectTensorNear<float>(*GetOutput(0),
                                test::AsTensor<float>({.4, .5, .6, .7}), 0.1);
}

TEST_F(RaggedGatherOpTest, RaggedGather_OutOfBounds) {
  // indices = [2, 10]
  // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]]
  BuildRaggedGatherGraph<float, int32>(
      TensorShape({2}),                     // indices.shape
      {2, 10},                              // indices
      {{0, 3, 3, 7, 9}},                    // params_nested_splits
      TensorShape({9}),                     // params_dense_values.shape
      {.1, .2, .3, .4, .5, .6, .7, .8, .9}  // params_dense_values
  );

  EXPECT_EQ("indices[1] = 10 is not in [0, 4)", RunOpKernel().error_message());
}

TEST_F(RaggedGatherOpTest, InvalidSplitsNotSorted) {
  BuildRaggedGatherGraph<float, int32>(
      TensorShape({2}),                     // indices.shape
      {0, 2},                               // indices
      {{0, 3, 5, 2, 9}},                    // params_nested_splits
      TensorShape({9}),                     // params_dense_values.shape
      {.1, .2, .3, .4, .5, .6, .7, .8, .9}  // params_dense_values
  );

  EXPECT_EQ("Ragged splits must be sorted", RunOpKernel().error_message());
}

TEST_F(RaggedGatherOpTest, InvalidSplitsNegative) {
  BuildRaggedGatherGraph<float, int32>(
      TensorShape({2}),                     // indices.shape
      {0, 2},                               // indices
      {{-1, 3, 2, 7, 9}},                   // params_nested_splits
      TensorShape({9}),                     // params_dense_values.shape
      {.1, .2, .3, .4, .5, .6, .7, .8, .9}  // params_dense_values
  );

  EXPECT_EQ("Ragged splits must be non-negative",
            RunOpKernel().error_message());
}

TEST_F(RaggedGatherOpTest, InvalidSplitsEmpty) {
  BuildRaggedGatherGraph<float, int32>(
      TensorShape({0}),  // indices.shape
      {},                // indices
      {{}},              // params_nested_splits
      TensorShape({0}),  // params_dense_values.shape
      {}                 // params_dense_values
  );

  EXPECT_EQ("Ragged splits may not be empty", RunOpKernel().error_message());
}

TEST_F(RaggedGatherOpTest, InvalidSplitsTooBig) {
  BuildRaggedGatherGraph<float, int32>(
      TensorShape({2}),                     // indices.shape
      {0, 2},                               // indices
      {{0, 20, 40, 80, 100}},               // params_nested_splits
      TensorShape({9}),                     // params_dense_values.shape
      {.1, .2, .3, .4, .5, .6, .7, .8, .9}  // params_dense_values
  );

  EXPECT_EQ("Ragged splits must not point past values",
            RunOpKernel().error_message());
}

TEST_F(RaggedGatherOpTest, BadValuesShape) {
  BuildRaggedGatherGraph<float, int32>(
      TensorShape({0}),  // indices.shape
      {},                // indices
      {{0}},             // params_nested_splits
      TensorShape({}),   // params_dense_values.shape
      {.1}               // params_dense_values
  );

  EXPECT_EQ("params.rank must be nonzero", RunOpKernel().error_message());
}

TEST_F(RaggedGatherOpTest, ShapeFn) {
  // RaggedGather(param_splits+, param_values, indices) -> [splits+, values]
  ShapeInferenceTestOp op("RaggedGather");

  (*op.node_def.mutable_attr())["PARAMS_RAGGED_RANK"].set_i(1);
  (*op.node_def.mutable_attr())["OUTPUT_RAGGED_RANK"].set_i(1);
  INFER_OK(op, "?;?;?", "[?];?");
  INFER_OK(op, "[?];[?];[?]", "[?];[?]");
  INFER_OK(op, "[?];[?,?,?];[?]", "[?];[?,d1_1,d1_2]");
  INFER_OK(op, "[5];[10];[15]", "[?];[?]");
  INFER_OK(op, "[5];[10,2];[15]", "[?];[?,d1_1]");
  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[5];[];[]");
  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[1,2];[];[5]");

  (*op.node_def.mutable_attr())["PARAMS_RAGGED_RANK"].set_i(2);
  (*op.node_def.mutable_attr())["OUTPUT_RAGGED_RANK"].set_i(2);
  INFER_OK(op, "?;?;?;?", "[?];[?];?");
  INFER_OK(op, "[?];[?];[?];[?]", "[?];[?];[?]");
  INFER_OK(op, "[?];[?];[?,?,?];[?]", "[?];[?];[?,d2_1,d2_2]");
  INFER_OK(op, "[5];[10];[15];[20]", "[?];[?];[?]");

  (*op.node_def.mutable_attr())["PARAMS_RAGGED_RANK"].set_i(1);
  (*op.node_def.mutable_attr())["OUTPUT_RAGGED_RANK"].set_i(2);
  INFER_OK(op, "?;?;?", "[?];[?];?");
  INFER_OK(op, "[?];[?];[?,?]", "[?];[?];[?]");
  INFER_OK(op, "[?];[?,?,?];[?,?]", "[?];[?];[?,d1_1,d1_2]");
  INFER_OK(op, "[15];[20];[5,10]", "[?];[?];[?]");
  INFER_OK(op, "[15];[20,2];[5,10]", "[?];[?];[?,d1_1]");

  (*op.node_def.mutable_attr())["PARAMS_RAGGED_RANK"].set_i(1);
  (*op.node_def.mutable_attr())["OUTPUT_RAGGED_RANK"].set_i(0);
  INFER_OK(op, "[?];[?];[]", "[?]");
}

}  // namespace
}  // namespace tensorflow