path: root/tensorflow/core/kernels/scatter_nd_op_test.cc
author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-11-02 15:57:45 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-11-02 17:10:10 -0700
commit    9a66842db8568e90bfef7c3c7836c052b776f17c (patch)
tree      678aea429196209d3368dcb5936105efd60acd51 /tensorflow/core/kernels/scatter_nd_op_test.cc
parent    f9b6d55ffd630082efd088fca927f7d991fdf3fa (diff)
Adding CPU kernels for tf.scatter_nd(), tf.scatter_nd_update(), tf.scatter_nd_add(), tf.scatter_nd_sub(), tf.scatter_nd_mul() and tf.scatter_nd_div() as well as gradient functions for tf.scatter_nd() and tf.gather_nd()
Change: 138013328
Diffstat (limited to 'tensorflow/core/kernels/scatter_nd_op_test.cc')
-rw-r--r--  tensorflow/core/kernels/scatter_nd_op_test.cc  |  320
1 file changed, 320 insertions, 0 deletions
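
For context, a minimal Python sketch of the semantics these CPU kernels back, mirroring the Simple_OneD test in the new file (the graph-mode Session usage and the exact tf.scatter_nd(indices, updates, shape) signature are shown for illustration only and are not part of this diff):

    import tensorflow as tf

    # Scatter three values into a length-5 vector of zeros.
    indices = tf.constant([[0], [4], [2]])      # write positions
    updates = tf.constant([100., 101., 102.])   # values to write
    result = tf.scatter_nd(indices, updates, shape=[5])

    with tf.Session() as sess:
        print(sess.run(result))  # [100.   0. 102.   0. 101.]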
diff --git a/tensorflow/core/kernels/scatter_nd_op_test.cc b/tensorflow/core/kernels/scatter_nd_op_test.cc
new file mode 100644
index 0000000000..d6743a6867
--- /dev/null
+++ b/tensorflow/core/kernels/scatter_nd_op_test.cc
@@ -0,0 +1,320 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+class ScatterNdUpdateOpTest : public OpsTestBase {
+ protected:
+ void MakeOp(DataType variable_ref_type, DataType index_type) {
+ TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterNdUpdate")
+ .Input(FakeInput(variable_ref_type))
+ .Input(FakeInput(index_type))
+ .Input(FakeInput(RemoveRefType(variable_ref_type)))
+ .Finalize(node_def()));
+ TF_ASSERT_OK(InitOp());
+ }
+};
+
+TEST_F(ScatterNdUpdateOpTest, Simple_StringType) {
+ MakeOp(DT_STRING_REF, DT_INT32);
+ AddInputFromArray<string>(TensorShape({1}), {"Brain"});
+ AddInputFromArray<int32>(TensorShape({1}), {0});
+ AddInputFromArray<string>(TensorShape({1}), {"TensorFlow"});
+ TF_ASSERT_OK(RunOpKernel());
+ // Check the new state of the input
+ Tensor params_tensor = *mutable_input(0).tensor;
+ Tensor expected(allocator(), DT_STRING, TensorShape({1}));
+ test::FillValues<string>(&expected, {"TensorFlow"});
+ test::ExpectTensorEqual<string>(expected, params_tensor);
+}
+
+TEST_F(ScatterNdUpdateOpTest, Simple_BoolType) {
+ MakeOp(DT_BOOL_REF, DT_INT32);
+ AddInputFromArray<bool>(TensorShape({1}), {false});
+ AddInputFromArray<int32>(TensorShape({1}), {0});
+ AddInputFromArray<bool>(TensorShape({1}), {true});
+ TF_ASSERT_OK(RunOpKernel());
+ // Check the new state of the input
+ Tensor params_tensor = *mutable_input(0).tensor;
+ Tensor expected(allocator(), DT_BOOL, TensorShape({1}));
+ test::FillValues<bool>(&expected, {true});
+ test::ExpectTensorEqual<bool>(expected, params_tensor);
+}
+
+TEST_F(ScatterNdUpdateOpTest, Simple_TwoD32) {
+ MakeOp(DT_FLOAT_REF, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 3}),
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
+ AddInputFromArray<float>(TensorShape({3, 3}),
+ {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the new state of the input
+ Tensor params_tensor = *mutable_input(0).tensor;
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
+ test::FillValues<float>(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001,
+ 10002, 0, 0, 0, 777, 778, 779});
+ test::ExpectTensorEqual<float>(expected, params_tensor);
+}
+
+TEST_F(ScatterNdUpdateOpTest, Simple_Two64) {
+ MakeOp(DT_FLOAT_REF, DT_INT64);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 3}),
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ AddInputFromArray<int64>(TensorShape({3, 1}), {0, 4, 2});
+ AddInputFromArray<float>(TensorShape({3, 3}),
+ {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the new state of the input
+ Tensor params_tensor = *mutable_input(0).tensor;
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
+ test::FillValues<float>(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001,
+ 10002, 0, 0, 0, 777, 778, 779});
+ test::ExpectTensorEqual<float>(expected, params_tensor);
+}
+/*TEST_F(ScatterNdUpdateOpTest, Simple_ZeroElements) {
+ MakeOp(DT_FLOAT_REF, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({0}), {});
+ AddInputFromArray<int32>(TensorShape({0}), {});
+ AddInputFromArray<float>(TensorShape({0}), {});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("Output must not have 0 elements, got shape: "))
+ << s;
+}*/
+
+TEST_F(ScatterNdUpdateOpTest, Simple_ZeroD) {
+ MakeOp(DT_FLOAT_REF, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5}), {0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({1}), {3});
+ AddInputFromArray<float>(TensorShape({1}), {101});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the new state of the input
+ Tensor params_tensor = *mutable_input(0).tensor;
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({5}));
+ test::FillValues<float>(&expected, {0, 0, 0, 101, 0});
+ test::ExpectTensorEqual<float>(expected, params_tensor);
+}
+
+TEST_F(ScatterNdUpdateOpTest, Simple_OneD) {
+ MakeOp(DT_FLOAT_REF, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5}), {0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
+ AddInputFromArray<float>(TensorShape({3}), {100, 101, 102});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the new state of the input
+ Tensor params_tensor = *mutable_input(0).tensor;
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({5}));
+ test::FillValues<float>(&expected, {100, 0, 102, 0, 101});
+ test::ExpectTensorEqual<float>(expected, params_tensor);
+}
+
+TEST_F(ScatterNdUpdateOpTest, HigherRank) {
+ MakeOp(DT_FLOAT_REF, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({8}), {0, 0, 0, 0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({2, 3, 1}), {0, 4, 2, 1, 3, 6});
+ AddInputFromArray<float>(TensorShape({2, 3}), {10, 20, 30, 40, 50, 60});
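+  // indices has shape [2, 3, 1]: the trailing dimension of size 1 indexes into
+  // the rank-1 params, so the 2 * 3 = 6 updates {10..60} land at positions
+  // {0, 4, 2, 1, 3, 6} respectively.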
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the new state of the input
+ Tensor params_tensor = *mutable_input(0).tensor;
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({8}));
+ test::FillValues<float>(&expected, {10, 40, 30, 50, 20, 0, 60, 0});
+ test::ExpectTensorEqual<float>(expected, params_tensor);
+}
+
+TEST_F(ScatterNdUpdateOpTest, Error_IndexOutOfRange) {
+ MakeOp(DT_FLOAT_REF, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 3}),
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 99});
+ AddInputFromArray<float>(TensorShape({3, 3}),
+ {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("Invalid indices: [2,0] = [99] is not in [0, 5)"))
+ << s;
+}
+
+TEST_F(ScatterNdUpdateOpTest, Error_WrongDimsIndices) {
+ MakeOp(DT_FLOAT_REF, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({2, 3}), {0, 0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({1, 3, 1}), {0, 4, 99});
+ AddInputFromArray<float>(TensorShape({3, 3}),
+ {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("The outermost dimension of updates and indices "
+ "must match. Got indices.shape [1,3,1], "
+ "updates.shape [3,3]"))
+ << s;
+}
+
+TEST_F(ScatterNdUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) {
+ MakeOp(DT_FLOAT_REF, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 3}),
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
+ AddInputFromArray<float>(
+ TensorShape({3, 4}),
+ {100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("Must have updates.shape = indices.shape[0] + "
+ "params_shape[IXDIM:], got"))
+ << s;
+}
+
+TEST_F(ScatterNdUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) {
+ MakeOp(DT_FLOAT_REF, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 3}),
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
+ AddInputFromArray<float>(TensorShape({2, 3}),
+ {100, 101, 102, 10000, 10001, 10002});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("The outermost dimension of updates and indices "
+ "must match. Got "))
+ << s;
+}
+
+class ScatterNdUpdateBM : public ScatterNdUpdateOpTest {
+ public:
+ virtual void TestBody() {}
+ void MakeBenchmarkOp(const char* op, DataType index_type) {
+ TF_ASSERT_OK(NodeDefBuilder("myop", op)
+ .Input(FakeInput(DT_FLOAT_REF))
+ .Input(FakeInput(index_type))
+ .Input(FakeInput(DT_FLOAT))
+ .Finalize(node_def()));
+ TF_CHECK_OK(InitOp());
+ }
+};
+
+template <typename Index>
+static void BM_ScatterNdHelper(int iters, int embedding_size, const char* op) {
+ testing::StopTiming();
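+  // Size the parameter table to roughly 10M floats regardless of
+  // embedding_size, so every Arg() variant processes a comparable volume of data.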
+ const int kRows = 10000000 / embedding_size;
+ std::vector<float> values;
+  values.reserve(kRows * embedding_size);  // the loop below fills kRows * embedding_size entries
+ for (int i = 0; i < kRows * embedding_size; i++) {
+ values.push_back(i);
+ }
+ const int kNumUpdates = 1000;
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ std::vector<Index> indices;
+ std::vector<float> updates;
+ for (int i = 0; i < kNumUpdates; i++) {
+ indices.push_back(rnd.Uniform(kRows));
+ for (int j = 0; j < embedding_size; j++) {
+ updates.push_back(i * 10 + j);
+ }
+ }
+
+ ScatterNdUpdateBM bm;
+ bm.MakeBenchmarkOp(op, DataTypeToEnum<Index>::v());
+ bm.AddInputFromArray<float>(TensorShape({kRows, embedding_size}), values);
+ bm.AddInputFromArray<Index>(TensorShape({kNumUpdates}), indices);
+ bm.AddInputFromArray<float>(TensorShape({kNumUpdates, embedding_size}),
+ updates);
+ testing::ItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) *
+ iters);
+ testing::StartTiming();
+ while (iters-- > 0) {
+ Status s = bm.RunOpKernel();
+ }
+ testing::StopTiming();
+}
+
+static void BM_ScatterNdUpdateInt32(int iters, int embedding_size) {
+ BM_ScatterNdHelper<int32>(iters, embedding_size, "ScatterNdUpdate");
+}
+static void BM_ScatterNdUpdateInt64(int iters, int embedding_size) {
+ BM_ScatterNdHelper<int64>(iters, embedding_size, "ScatterNdUpdate");
+}
+
+static void BM_ScatterNdAddInt32(int iters, int embedding_size) {
+ BM_ScatterNdHelper<int32>(iters, embedding_size, "ScatterNdAdd");
+}
+static void BM_ScatterNdAddInt64(int iters, int embedding_size) {
+ BM_ScatterNdHelper<int64>(iters, embedding_size, "ScatterNdAdd");
+}
+
+BENCHMARK(BM_ScatterNdUpdateInt32)
+ ->Arg(1)
+ ->Arg(10)
+ ->Arg(64)
+ ->Arg(256)
+ ->Arg(1024);
+BENCHMARK(BM_ScatterNdUpdateInt64)
+ ->Arg(1)
+ ->Arg(10)
+ ->Arg(64)
+ ->Arg(256)
+ ->Arg(1024);
+
+BENCHMARK(BM_ScatterNdAddInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+BENCHMARK(BM_ScatterNdAddInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+
+} // namespace
+} // namespace tensorflow