aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/hadoop
diff options
context:
space:
mode:
authorGravatar Yong Tang <yong.tang.github@outlook.com>2018-05-23 19:03:45 +0000
committerGravatar Yong Tang <yong.tang.github@outlook.com>2018-06-29 21:53:55 +0000
commit01e125d0ccce72948aa9a0f80d1fc6e343e81a10 (patch)
tree316ac76015242513a984428ee0a03520f9680a69 /tensorflow/contrib/hadoop
parentfa9f9a0f76c86bfabd2e63db2add2829a749f639 (diff)
Add test case for hadoop SequenceFileDataset
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/contrib/hadoop')
-rw-r--r--tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py66
-rwxr-xr-xtensorflow/contrib/hadoop/python/kernel_tests/testdata/string.seqbin0 -> 603 bytes
2 files changed, 66 insertions, 0 deletions
diff --git a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py
new file mode 100644
index 0000000000..8bbb1da85d
--- /dev/null
+++ b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py
@@ -0,0 +1,66 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ==============================================================================
+"""Tests for SequenceFileDataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.contrib.hadoop.python.ops import hadoop_dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import test
+
+class SequenceFileDatasetTest(test.TestCase):
+
+ def test_sequence_file_dataset(self):
+ """Test case for SequenceFileDataset.
+
+ The file is generated with `org.apache.hadoop.io.Text` for key/value.
+ There are 25 records in the file with the format of:
+ key = XXX
+ value = VALUEXXX
+ where XXX is replaced as the line number (starts with 001).
+ """
+ filename = os.path.join(
+ resource_loader.get_data_files_path(), 'testdata/string.seq')
+
+ filenames = array_ops.placeholder_with_default(
+ constant_op.constant([filename], dtypes.string), shape=[None])
+ num_repeats = 2
+
+ dataset = hadoop_dataset_ops.SequenceFileDataset(
+ filenames).repeat(num_repeats)
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(init_op)
+ for _ in range(num_repeats): # Dataset is repeated.
+ for i in range(25): # 25 records.
+ v0 = "{0:03d}".format(i + 1)
+ v1 = "VALUE{0:03d}".format(i + 1)
+ self.assertEqual((v0, v1), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/hadoop/python/kernel_tests/testdata/string.seq b/tensorflow/contrib/hadoop/python/kernel_tests/testdata/string.seq
new file mode 100755
index 0000000000..b7175338af
--- /dev/null
+++ b/tensorflow/contrib/hadoop/python/kernel_tests/testdata/string.seq
Binary files differ