aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/dataset_utils_test.cc
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-10-08 10:14:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 10:23:50 -0700
commit0e42fd6d0a88b30ab57959f38c79bea19d745ec3 (patch)
tree177a8662e421aa2c4d57d1b1caf370077de0b55c /tensorflow/core/kernels/data/dataset_utils_test.cc
parent0e1ba8886b6a333b1ed8ed7548c55041c34e9623 (diff)
[tf.data] Adding specialization for `MapDataset`, `ParallelMapDataset`, and `MapAndBatchDataset` whose user-provided functions have the property that each output argument take its value directly from an input argument (e.g. `lambda x, y: y, x`). This specialization can produce the result without having to schedule the function using the executor.
PiperOrigin-RevId: 216206232
Diffstat (limited to 'tensorflow/core/kernels/data/dataset_utils_test.cc')
-rw-r--r--tensorflow/core/kernels/data/dataset_utils_test.cc46
1 files changed, 46 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/data/dataset_utils_test.cc b/tensorflow/core/kernels/data/dataset_utils_test.cc
new file mode 100644
index 0000000000..43295b8ebb
--- /dev/null
+++ b/tensorflow/core/kernels/data/dataset_utils_test.cc
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/dataset_utils.h"
+
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+TEST(DatasetUtils, ComputeMoveVector) {
+ struct TestCase {
+ std::vector<int> indices;
+ std::vector<bool> expected;
+ };
+
+ TestCase test_cases[] = {
+ TestCase{{}, {}},
+ TestCase{{1}, {true}},
+ TestCase{{1, 1}, {false, true}},
+ TestCase{{1, 2}, {true, true}},
+ TestCase{{1, 1, 2}, {false, true, true}},
+ TestCase{{1, 2, 2}, {true, false, true}},
+ };
+
+ for (auto& test_case : test_cases) {
+ EXPECT_EQ(test_case.expected, ComputeMoveVector(test_case.indices));
+ }
+}
+
+} // namespace
+} // namespace data
+} // namespace tensorflow