From b99c3b0db167f2c1719a8d5d3ce3c8b3de867e47 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 4 Aug 2018 22:25:43 +0000 Subject: Add test cases for LMDBDataset Signed-off-by: Yong Tang --- tensorflow/contrib/data/python/kernel_tests/BUILD | 25 +++++++ .../python/kernel_tests/lmdb_dataset_op_test.py | 76 ++++++++++++++++++++++ 2 files changed, 101 insertions(+) create mode 100644 tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py (limited to 'tensorflow/contrib/data') diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 2b75aa2ca5..74d0a30eee 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -179,6 +179,31 @@ py_test( ], ) +py_test( + name = "lmdb_dataset_op_test", + size = "medium", + srcs = ["lmdb_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_windows", + "no_pip", + ], + data = ["//tensorflow/core:lmdb_testdata"], + deps = [ + "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//tensorflow/python:session", + "//third_party/py/numpy", + ], +) + py_test( name = "map_dataset_op_test", size = "medium", diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py new file mode 100644 index 0000000000..5d7d4da113 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py @@ -0,0 +1,76 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for LMDBDatasetOp.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil + +from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import test +from tensorflow.python.util import compat + +class LMDBDatasetTest(test.TestCase): + + def setUp(self): + super(LMDBDatasetTest, self).setUp() + path = os.path.join( + resource_loader.get_root_dir_with_all_resources(), + "tensorflow", + "core", + "lib", + "lmdb", + "testdata", + "data.mdb") + + print(path) + # Copy database out because we need the path to be writable to use locks. + self.db_path = os.path.join(self.get_temp_dir(), "data.mdb") + shutil.copy(path, self.db_path) + + def testReadFromFile(self): + filename = self.db_path + + filenames = constant_op.constant([filename], dtypes.string) + num_repeats = 2 + + dataset = readers.LMDBDataset( + filenames).repeat(num_repeats) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for _ in range(num_repeats): # Dataset is repeated. + for i in range(10): # 10 records. + k = compat.as_bytes(str(i)) + v = compat.as_bytes(str(chr(ord("a") + i))) + self.assertEqual((k, v), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + +if __name__ == "__main__": + test.main() -- cgit v1.2.3