aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data
diff options
context:
space:
mode:
authorGravatar Yong Tang <yong.tang.github@outlook.com>2018-08-04 22:25:43 +0000
committerGravatar Yong Tang <yong.tang.github@outlook.com>2018-08-21 00:36:04 +0000
commitb99c3b0db167f2c1719a8d5d3ce3c8b3de867e47 (patch)
tree9b8fe947ad449295ab996c9d90f6790f504c62cc /tensorflow/contrib/data
parent3da376758711410c374329b831a99c483c7d9299 (diff)
Add test cases for LMDBDataset
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD25
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py76
2 files changed, 101 insertions, 0 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 2b75aa2ca5..74d0a30eee 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -180,6 +180,31 @@ py_test(
)
py_test(
+ name = "lmdb_dataset_op_test",
+ size = "medium",
+ srcs = ["lmdb_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_windows",
+ "no_pip",
+ ],
+ data = ["//tensorflow/core:lmdb_testdata"],
+ deps = [
+ "//tensorflow/contrib/data/python/ops:readers",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:session",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "map_dataset_op_test",
size = "medium",
srcs = ["map_dataset_op_test.py"],
diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
new file mode 100644
index 0000000000..5d7d4da113
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
@@ -0,0 +1,76 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for LMDBDatasetOp."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+
+from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+class LMDBDatasetTest(test.TestCase):
+
+ def setUp(self):
+ super(LMDBDatasetTest, self).setUp()
+ path = os.path.join(
+ resource_loader.get_root_dir_with_all_resources(),
+ "tensorflow",
+ "core",
+ "lib",
+ "lmdb",
+ "testdata",
+ "data.mdb")
+
+ print(path)
+ # Copy database out because we need the path to be writable to use locks.
+ self.db_path = os.path.join(self.get_temp_dir(), "data.mdb")
+ shutil.copy(path, self.db_path)
+
+ def testReadFromFile(self):
+ filename = self.db_path
+
+ filenames = constant_op.constant([filename], dtypes.string)
+ num_repeats = 2
+
+ dataset = readers.LMDBDataset(
+ filenames).repeat(num_repeats)
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(init_op)
+ for _ in range(num_repeats): # Dataset is repeated.
+ for i in range(10): # 10 records.
+ k = compat.as_bytes(str(i))
+ v = compat.as_bytes(str(chr(ord("a") + i)))
+ self.assertEqual((k, v), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+if __name__ == "__main__":
+ test.main()