diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/scatter_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/scatter_test.cc | 172 |
1 files changed, 76 insertions, 96 deletions
diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc index 1858dcea61..d20dba028a 100644 --- a/tensorflow/compiler/xla/tests/scatter_test.cc +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -62,13 +62,11 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR1<int32>({0, 2}); - std::unique_ptr<Literal> updates = - LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2}); + Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) { @@ -92,13 +90,12 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR1<int32>({0, 2}); - std::unique_ptr<Literal> updates = + Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2}); + Literal updates = LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) { @@ -123,13 +120,11 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR1<int32>({0, 2}); - std::unique_ptr<Literal> updates = - LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2}); + Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) { @@ -154,13 +149,11 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR1<int32>({0, 2}); - std::unique_ptr<Literal> updates = - LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2}); + Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_F32) { @@ -185,13 +178,12 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<float>( + Literal operand = LiteralUtil::CreateR2<float>( {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR1<int32>({2, 1}); - std::unique_ptr<Literal> updates = + Literal scatter_indices = LiteralUtil::CreateR1<int32>({2, 1}); + Literal updates = LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_RepeatedIndices) { @@ -216,13 +208,11 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR1<int32>({1, 1}); - std::unique_ptr<Literal> updates = - LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1}); + Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_MultipleBatchDims) { @@ -247,13 +237,12 @@ ENTRY main { index_vector_dim=2 } )"; - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}}); - std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>( + Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}}); + Literal updates = LiteralUtil::CreateR3<int32>( {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatterNd) { @@ -277,15 +266,13 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); - std::unique_ptr<Literal> updates = - LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatterNd_NonDefaultIndexVectorDim) { @@ -309,15 +296,13 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); - std::unique_ptr<Literal> updates = - LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, DynamicUpdateSlice) { @@ -341,12 +326,11 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR1<int32>({1, 1}); - std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{10}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1}); + Literal updates = LiteralUtil::CreateR2<int32>({{10}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, BatchDynamicUpdateSlice) { @@ -370,13 +354,11 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}}); - std::unique_ptr<Literal> updates = - LiteralUtil::CreateR3<int32>({{{10}}, {{20}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}}); + Literal updates = LiteralUtil::CreateR3<int32>({{{10}}, {{20}}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, ZeroDimBounds) { @@ -400,11 +382,10 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR1<int32>({0, 2}); - std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{}, {}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}}); + Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2}); + Literal updates = LiteralUtil::CreateR2<int32>({{}, {}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, NoUpdateWindowDims) { @@ -429,12 +410,11 @@ ENTRY main { index_vector_dim=2 } )"; - std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2}); - std::unique_ptr<Literal> scatter_indices = + Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2}); + Literal scatter_indices = LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}}); - std::unique_ptr<Literal> updates = - LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal updates = LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, OutOfBoundsIndex) { @@ -458,13 +438,13 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<int32>( + Literal scatter_indices = LiteralUtil::CreateR2<int32>( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); - std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>( + Literal updates = LiteralUtil::CreateR3<int32>( {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, OutOfBoundsUnsignedIndex) { @@ -488,13 +468,13 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<uint32>( + Literal scatter_indices = LiteralUtil::CreateR2<uint32>( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}}); - std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>( + Literal updates = LiteralUtil::CreateR3<int32>( {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, NegativeIndex) { @@ -518,13 +498,13 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<int32>( + Literal scatter_indices = LiteralUtil::CreateR2<int32>( {{2, 7}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>( + Literal updates = LiteralUtil::CreateR3<int32>( {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, OneScalarIndex) { @@ -548,12 +528,12 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>( + Literal operand = LiteralUtil::CreateR3<int32>( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR0<int32>(1); - std::unique_ptr<Literal> updates = + Literal scatter_indices = LiteralUtil::CreateR0<int32>(1); + Literal updates = LiteralUtil::CreateR3<int32>({{{10, 20}, {30, 40}, {50, 60}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, ScalarUpdate) { @@ -577,10 +557,10 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4}); - std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR0<int32>(1); - std::unique_ptr<Literal> updates = LiteralUtil::CreateR0<int32>(25); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4}); + Literal scatter_indices = LiteralUtil::CreateR0<int32>(1); + Literal updates = LiteralUtil::CreateR0<int32>(25); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, EmptyIndices) { @@ -604,10 +584,10 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3}); - std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR1<int32>({}); - std::unique_ptr<Literal> updates = LiteralUtil::CreateR1<int32>({}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal operand = LiteralUtil::CreateR1<int32>({1, 2, 3}); + Literal scatter_indices = LiteralUtil::CreateR1<int32>({}); + Literal updates = LiteralUtil::CreateR1<int32>({}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } } // namespace |