From 1c4a48ddd49f78fbd8ea3defd3a8755c91284166 Mon Sep 17 00:00:00 2001 From: Rachel Lim Date: Fri, 28 Sep 2018 15:22:06 -0700 Subject: [tf.data] Merged contrib.data's DatasetTestBase with the DatasetTestBase in core (and added that as a base class for all the contrib tests). Also changed the assertDatasetsEqual functions so they are both graph and eager compatible (took the code from CSVDatasetTest) :) PiperOrigin-RevId: 215004892 --- tensorflow/contrib/data/python/kernel_tests/BUILD | 37 +++++++--- .../python/kernel_tests/batch_dataset_op_test.py | 9 +-- .../data/python/kernel_tests/bucketing_test.py | 9 +-- .../python/kernel_tests/csv_dataset_op_test.py | 43 ++---------- .../kernel_tests/dataset_constructor_op_test.py | 3 +- .../directed_interleave_dataset_test.py | 3 +- .../python/kernel_tests/get_single_element_test.py | 3 +- .../kernel_tests/indexed_dataset_ops_test.py | 3 +- .../kernel_tests/interleave_dataset_op_test.py | 3 +- .../data/python/kernel_tests/iterator_ops_test.py | 3 +- .../python/kernel_tests/lmdb_dataset_op_test.py | 3 +- .../python/kernel_tests/map_dataset_op_test.py | 3 +- .../data/python/kernel_tests/map_defun_op_test.py | 4 +- .../data/python/kernel_tests/optimization/BUILD | 9 ++- .../optimization/assert_next_dataset_op_test.py | 3 +- .../optimization/hoist_random_uniform_test.py | 3 +- .../optimization/map_and_filter_fusion_test.py | 3 +- .../optimization/map_parallelization_test.py | 3 +- .../optimization/map_vectorization_test.py | 14 ++-- .../optimization/model_dataset_op_test.py | 3 +- .../optimization/noop_elimination_test.py | 3 +- .../optimization/optimize_dataset_op_test.py | 3 +- .../data/python/kernel_tests/parsing_ops_test.py | 3 +- .../python/kernel_tests/prefetching_ops_test.py | 7 +- .../python/kernel_tests/range_dataset_op_test.py | 3 +- .../python/kernel_tests/reader_dataset_ops_test.py | 3 +- .../kernel_tests/reader_dataset_ops_test_base.py | 10 +-- .../data/python/kernel_tests/resample_test.py | 3 +- .../python/kernel_tests/scan_dataset_op_test.py | 3 +- .../python/kernel_tests/shuffle_dataset_op_test.py | 3 +- .../python/kernel_tests/slide_dataset_op_test.py | 8 +-- .../kernel_tests/sql_dataset_op_test_base.py | 5 +- .../python/kernel_tests/stats_dataset_test_base.py | 4 +- .../contrib/data/python/kernel_tests/test_utils.py | 73 -------------------- .../kernel_tests/threadpool_dataset_ops_test.py | 4 +- .../python/kernel_tests/unique_dataset_op_test.py | 3 +- .../python/kernel_tests/window_dataset_op_test.py | 3 +- .../data/python/kernel_tests/writer_ops_test.py | 3 +- tensorflow/python/data/kernel_tests/BUILD | 3 + tensorflow/python/data/kernel_tests/test_base.py | 80 ++++++++++++++++++++++ tensorflow/tools/pip_package/BUILD | 1 - 41 files changed, 209 insertions(+), 183 deletions(-) delete mode 100644 tensorflow/contrib/data/python/kernel_tests/test_utils.py diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 21ac40eb21..33784afa3f 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -31,6 +31,7 @@ py_test( "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:util", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", @@ -54,6 +55,7 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -77,6 +79,7 @@ py_test( "//tensorflow/python:platform", "//tensorflow/python:platform_test", "//tensorflow/python:session", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:readers", "//tensorflow/python/eager:context", "//third_party/py/numpy", @@ -97,6 +100,7 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", ], @@ -112,6 +116,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:random_seed", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -130,6 +135,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "@absl_py//absl/testing:parameterized", ], @@ -144,6 +150,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:experimental_dataset_ops_gen", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -169,6 +176,7 @@ py_test( "//tensorflow/python:script_ops", "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "@six_archive//:six", ], @@ -188,6 +196,7 @@ py_test( "//tensorflow/python:framework_ops", "//tensorflow/python:training", "//tensorflow/python:variables", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/estimator:estimator_py", ], @@ -214,6 +223,7 @@ py_test( "//tensorflow/python:platform", "//tensorflow/python:platform_test", "//tensorflow/python:session", + "//tensorflow/python/data/kernel_tests:test_base", "//third_party/py/numpy", ], ) @@ -239,6 +249,7 @@ py_test( "//tensorflow/python:io_ops", "//tensorflow/python:math_ops", "//tensorflow/python:util", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -258,6 +269,7 @@ py_test( "//tensorflow/python:io_ops", "//tensorflow/python:math_ops", "//tensorflow/python:util", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -282,6 +294,7 @@ py_test( "//tensorflow/python:functional_ops", "//tensorflow/python:math_ops", "//tensorflow/python:session", + "//tensorflow/python/data/kernel_tests:test_base", ], ) @@ -300,6 +313,7 @@ py_test( "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", "//third_party/py/numpy", @@ -315,6 +329,7 @@ cuda_py_test( "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", @@ -340,6 +355,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -365,6 +381,7 @@ py_library( "//tensorflow/python:lib", "//tensorflow/python:parsing_ops", "//tensorflow/python:util", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:iterator_ops", "//tensorflow/python/data/ops:readers", ], @@ -411,6 +428,7 @@ py_test( "//tensorflow/python:random_ops", "//tensorflow/python:string_ops", "//tensorflow/python:util", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", @@ -433,6 +451,7 @@ py_test( "//tensorflow/python:errors", "//tensorflow/python:framework_test_lib", "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/eager:context", "//third_party/py/numpy", @@ -453,6 +472,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -470,6 +490,7 @@ py_test( "//tensorflow/python:errors", "//tensorflow/python:math_ops", "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", @@ -489,6 +510,7 @@ py_library( "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", + "//tensorflow/python/data/kernel_tests:test_base", "@org_sqlite//:python", ], ) @@ -533,6 +555,7 @@ py_library( deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", + "//tensorflow/python/data/kernel_tests:test_base", ], ) @@ -549,6 +572,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:script_ops", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", @@ -567,6 +591,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:util", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -587,6 +612,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:math_ops", "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", @@ -604,17 +630,8 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:lib", "//tensorflow/python:util", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:readers", ], ) - -py_library( - name = "test_utils", - srcs = ["test_utils.py"], - deps = [ - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python/data/util:nest", - ], -) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index e2508de9e9..fed7de5f2b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -25,6 +25,7 @@ import numpy as np from tensorflow.contrib.data.python.ops import batching from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -40,12 +41,8 @@ from tensorflow.python.platform import test from tensorflow.python.util import compat -class BatchDatasetTest(test.TestCase, parameterized.TestCase): +class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): - def assertSparseValuesEqual(self, a, b): - self.assertAllEqual(a.indices, b.indices) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.dense_shape, b.dense_shape) def testDenseToSparseBatchDataset(self): components = np.random.randint(12, size=(100,)).astype(np.int32) @@ -723,7 +720,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): self.assertAllEqual([element for _ in range(10)], sess.run(get_next)) -class RestructuredDatasetTest(test.TestCase): +class RestructuredDatasetTest(test_base.DatasetTestBase): def test_assert_element_shape(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index 48971f2ccc..ae401f786c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -22,6 +22,7 @@ import random import numpy as np from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -35,7 +36,7 @@ from tensorflow.python.ops import string_ops from tensorflow.python.platform import test -class GroupByReducerTest(test.TestCase): +class GroupByReducerTest(test_base.DatasetTestBase): def checkResults(self, dataset, shapes, values): self.assertEqual(shapes, dataset.output_shapes) @@ -198,7 +199,7 @@ class GroupByReducerTest(test.TestCase): self.assertEqual(y, 45) -class GroupByWindowTest(test.TestCase): +class GroupByWindowTest(test_base.DatasetTestBase): def testSimple(self): components = np.random.randint(100, size=(200,)).astype(np.int64) @@ -345,7 +346,7 @@ class GroupByWindowTest(test.TestCase): # NOTE(mrry): These tests are based on the tests in bucket_ops_test.py. # Currently, they use a constant batch size, though should be made to use a # different batch size per key. -class BucketTest(test.TestCase): +class BucketTest(test_base.DatasetTestBase): def _dynamicPad(self, bucket, window, window_size): # TODO(mrry): To match `tf.contrib.training.bucket()`, implement a @@ -570,7 +571,7 @@ def _get_record_shape(sparse): return tensor_shape.TensorShape([None]) -class BucketBySequenceLength(test.TestCase): +class BucketBySequenceLength(test_base.DatasetTestBase): def testBucket(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py index f8e74e4583..5b3c512b64 100644 --- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py @@ -30,6 +30,7 @@ import numpy as np from tensorflow.contrib.data.python.ops import error_ops from tensorflow.contrib.data.python.ops import readers from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.eager import context from tensorflow.python.framework import constant_op @@ -43,37 +44,7 @@ from tensorflow.python.platform import test @test_util.run_all_in_graph_and_eager_modes -class CsvDatasetOpTest(test.TestCase): - - def _get_next(self, dataset): - # Returns a no argument function whose result is fed to self.evaluate to - # yield the next element - it = dataset.make_one_shot_iterator() - if context.executing_eagerly(): - return it.get_next - else: - get_next = it.get_next() - return lambda: get_next - - def _assert_datasets_equal(self, ds1, ds2): - assert ds1.output_shapes == ds2.output_shapes, ('output_shapes differ: %s, ' - '%s') % (ds1.output_shapes, - ds2.output_shapes) - assert ds1.output_types == ds2.output_types - assert ds1.output_classes == ds2.output_classes - next1 = self._get_next(ds1) - next2 = self._get_next(ds2) - # Run through datasets and check that outputs match, or errors match. - while True: - try: - op1 = self.evaluate(next1()) - except (errors.OutOfRangeError, ValueError) as e: - # If op1 throws an exception, check that op2 throws same exception. - with self.assertRaises(type(e)): - self.evaluate(next2()) - break - op2 = self.evaluate(next2()) - self.assertAllEqual(op1, op2) +class CsvDatasetOpTest(test_base.DatasetTestBase): def _setup_files(self, inputs, linebreak='\n', compression_type=None): filenames = [] @@ -108,7 +79,7 @@ class CsvDatasetOpTest(test.TestCase): """Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv).""" dataset_actual, dataset_expected = self._make_test_datasets( inputs, **kwargs) - self._assert_datasets_equal(dataset_actual, dataset_expected) + self.assertDatasetsEqual(dataset_actual, dataset_expected) def _verify_output_or_err(self, dataset, @@ -116,7 +87,7 @@ class CsvDatasetOpTest(test.TestCase): expected_err_re=None): if expected_err_re is None: # Verify that output is expected, without errors - nxt = self._get_next(dataset) + nxt = self.getNext(dataset) expected_output = [[ v.encode('utf-8') if isinstance(v, str) else v for v in op ] for op in expected_output] @@ -128,7 +99,7 @@ class CsvDatasetOpTest(test.TestCase): else: # Verify that OpError is produced as expected with self.assertRaisesOpError(expected_err_re): - nxt = self._get_next(dataset) + nxt = self.getNext(dataset) while True: try: self.evaluate(nxt()) @@ -354,7 +325,7 @@ class CsvDatasetOpTest(test.TestCase): inputs = [['1,,3,4', '5,6,,8']] ds_actual, ds_expected = self._make_test_datasets( inputs, record_defaults=record_defaults) - self._assert_datasets_equal( + self.assertDatasetsEqual( ds_actual.repeat(5).prefetch(1), ds_expected.repeat(5).prefetch(1)) @@ -377,7 +348,7 @@ class CsvDatasetOpTest(test.TestCase): ds = readers.make_csv_dataset( file_path, batch_size=1, shuffle=False, num_epochs=1) - nxt = self._get_next(ds) + nxt = self.getNext(ds) result = list(self.evaluate(nxt()).values()) diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py index a2ab3de52e..722e87e555 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.data.python.ops import batching +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes @@ -25,7 +26,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class DatasetConstructorTest(test.TestCase): +class DatasetConstructorTest(test_base.DatasetTestBase): def testRestructureDataset(self): components = (array_ops.placeholder(dtypes.int32), diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py index eb110324d1..bc10c21472 100644 --- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py @@ -20,13 +20,14 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import random_seed from tensorflow.python.platform import test -class DirectedInterleaveDatasetTest(test.TestCase): +class DirectedInterleaveDatasetTest(test_base.DatasetTestBase): def testBasic(self): selector_dataset = dataset_ops.Dataset.range(10).repeat(100) diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py index f3968cdc15..cc22ea1df7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.contrib.data.python.ops import get_single_element from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -30,7 +31,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class GetSingleElementTest(test.TestCase, parameterized.TestCase): +class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.named_parameters( ("Zero", 0, 1), diff --git a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py index 46a7127b52..d4d3d4adb2 100644 --- a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import unittest from tensorflow.contrib.data.python.ops import indexed_dataset_ops +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -28,7 +29,7 @@ from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops from tensorflow.python.platform import test -class IndexedDatasetOpsTest(test.TestCase): +class IndexedDatasetOpsTest(test_base.DatasetTestBase): def testLowLevelIndexedDatasetOps(self): identity = ged_ops.experimental_identity_indexed_dataset( diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index b9e74dfddb..28bd670ab5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -25,6 +25,7 @@ import time from six.moves import zip_longest from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -36,7 +37,7 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test -class ParallelInterleaveDatasetTest(test.TestCase): +class ParallelInterleaveDatasetTest(test_base.DatasetTestBase): def setUp(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py index 7e2326bd17..58a1d7c93b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.data.python.ops import iterator_ops +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn @@ -33,7 +34,7 @@ from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util -class CheckpointInputPipelineHookTest(test.TestCase): +class CheckpointInputPipelineHookTest(test_base.DatasetTestBase): @staticmethod def _model_fn(features, labels, mode, config): diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py index 1cc5ddc9a2..d2a72272db 100644 --- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py @@ -22,6 +22,7 @@ import os import shutil from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -31,7 +32,7 @@ from tensorflow.python.util import compat prefix_path = "tensorflow/core/lib" -class LMDBDatasetTest(test.TestCase): +class LMDBDatasetTest(test_base.DatasetTestBase): def setUp(self): super(LMDBDatasetTest, self).setUp() diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index e8519381d6..385c4ef6ea 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -29,6 +29,7 @@ from tensorflow.contrib.data.python.ops import error_ops from tensorflow.contrib.data.python.ops import optimization from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -41,7 +42,7 @@ from tensorflow.python.util import compat _NUMPY_RANDOM_SEED = 42 -class MapDatasetTest(test.TestCase): +class MapDatasetTest(test_base.DatasetTestBase): def testMapIgnoreError(self): components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py index 25aea0393f..751e6d5b30 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py @@ -21,6 +21,7 @@ import time from tensorflow.contrib.data.python.ops import map_defun from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -33,7 +34,8 @@ from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class MapDefunTest(test.TestCase): + +class MapDefunTest(test_base.DatasetTestBase): def testMapDefunSimple(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD index 1ae92bdeff..d7b5edcd9a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD @@ -15,6 +15,7 @@ py_test( "//tensorflow/contrib/data/python/ops:optimization", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -31,6 +32,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:math_ops", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "@absl_py//absl/testing:parameterized", ], @@ -57,7 +59,6 @@ py_test( srcs = ["map_vectorization_test.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/data/python/kernel_tests:test_utils", "//tensorflow/contrib/data/python/ops:optimization", "//tensorflow/python:check_ops", "//tensorflow/python:client_testlib", @@ -67,6 +68,7 @@ py_test( "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:session", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", @@ -85,6 +87,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:math_ops", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "@absl_py//absl/testing:parameterized", ], @@ -102,6 +105,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:math_ops", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "@absl_py//absl/testing:parameterized", ], @@ -121,6 +125,7 @@ py_test( "//tensorflow/contrib/data/python/ops:optimization", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -137,6 +142,7 @@ py_test( "//tensorflow/contrib/data/python/ops:optimization", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -151,6 +157,7 @@ py_test( "//tensorflow/contrib/data/python/ops:optimization", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py index d10da80442..fe1b5280ba 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py @@ -18,12 +18,13 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.platform import test -class AssertNextDatasetTest(test.TestCase): +class AssertNextDatasetTest(test_base.DatasetTestBase): def testAssertNext(self): dataset = dataset_ops.Dataset.from_tensors(0).apply( diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py index 9518c2e1ad..b43efb5c7c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from absl.testing import parameterized from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -31,7 +32,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.platform import test -class HoistRandomUniformTest(test.TestCase, parameterized.TestCase): +class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase): @staticmethod def map_functions(): diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py index e75edf6086..e9e3fc81e5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from absl.testing import parameterized from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -28,7 +29,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase): +class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase): @staticmethod def map_functions(): diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py index dd547db086..f7907eb890 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from absl.testing import parameterized from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -30,7 +31,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.platform import test -class MapParallelizationTest(test.TestCase, parameterized.TestCase): +class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase): @staticmethod def map_functions(): diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py index 5b493f44c9..a5ea85f454 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py @@ -22,9 +22,9 @@ import time from absl.testing import parameterized import numpy as np -from tensorflow.contrib.data.python.kernel_tests import test_utils from tensorflow.contrib.data.python.ops import optimization from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -36,7 +36,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase): +class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): def _get_test_datasets(self, base_dataset, @@ -85,7 +85,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase): [3, 4]]).repeat(5) unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn, num_parallel_calls) - self._assert_datasets_equal(unoptimized, optimized) + self.assertDatasetsEqual(unoptimized, optimized) def testOptimizationBadMapFn(self): # Test map functions that give an error @@ -112,7 +112,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase): # TODO(rachelim): when this optimization works, turn on expect_optimized unoptimized, optimized = self._get_test_datasets( base_dataset, map_fn, expect_optimized=False) - self._assert_datasets_equal(optimized, unoptimized) + self.assertDatasetsEqual(optimized, unoptimized) def testOptimizationIgnoreStateful(self): @@ -124,7 +124,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase): [3, 4]]).repeat(5) unoptimized, optimized = self._get_test_datasets( base_dataset, map_fn, expect_optimized=False) - self._assert_datasets_raise_same_error( + self.assertDatasetsRaiseSameError( unoptimized, optimized, errors.InvalidArgumentError, [("OneShotIterator", "OneShotIterator_1", 1), ("IteratorGetNext", "IteratorGetNext_1", 1)]) @@ -138,7 +138,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase): base_dataset = dataset_ops.Dataset.range(20).batch(3, drop_remainder=False) unoptimized, optimized = self._get_test_datasets( base_dataset, map_fn, expect_optimized=False) - self._assert_datasets_equal(unoptimized, optimized) + self.assertDatasetsEqual(unoptimized, optimized) def testOptimizationIgnoreRaggedMap(self): # Don't optimize when the output of the map fn shapes are unknown. @@ -148,7 +148,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase): base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True) unoptimized, optimized = self._get_test_datasets( base_dataset, map_fn, expect_optimized=False) - self._assert_datasets_raise_same_error( + self.assertDatasetsRaiseSameError( unoptimized, optimized, errors.InvalidArgumentError, [("OneShotIterator", "OneShotIterator_1", 1), ("IteratorGetNext", "IteratorGetNext_1", 1)]) diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py index 3b62a7e468..33c250ab2a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py @@ -23,12 +23,13 @@ import numpy as np from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class ModelDatasetTest(test.TestCase): +class ModelDatasetTest(test_base.DatasetTestBase): def testModelMap(self): k = 1024 * 1024 diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py index 507feda3ad..b9e60cfa4e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -26,7 +27,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class NoopEliminationTest(test.TestCase): +class NoopEliminationTest(test_base.DatasetTestBase): def testNoopElimination(self): a = constant_op.constant(1, dtype=dtypes.int64) diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py index a3fb824ce9..04f499f8c5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -28,7 +29,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.platform import test -class OptimizeDatasetTest(test.TestCase): +class OptimizeDatasetTest(test_base.DatasetTestBase): def testOptimizationDefault(self): dataset = dataset_ops.Dataset.range(10).apply( diff --git a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py index c4623bca73..66ccaceea5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py @@ -25,6 +25,7 @@ import numpy as np from tensorflow.contrib.data.python.ops import parsing_ops as contrib_parsing_ops from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes @@ -72,7 +73,7 @@ def _compare_output_to_expected(tester, dict_tensors, expected_tensors, i += 1 -class ParseExampleTest(test.TestCase): +class ParseExampleTest(test_base.DatasetTestBase): def _test(self, input_tensor, diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py index 33a64ea767..7a6a7a709a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py @@ -22,6 +22,7 @@ import threading from tensorflow.contrib.data.python.ops import prefetching_ops from tensorflow.core.protobuf import config_pb2 from tensorflow.python.compat import compat +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op @@ -35,7 +36,7 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test -class PrefetchingKernelsOpsTest(test.TestCase): +class PrefetchingKernelsOpsTest(test_base.DatasetTestBase): def setUp(self): self._event = threading.Event() @@ -244,7 +245,7 @@ class PrefetchingKernelsOpsTest(test.TestCase): sess.run(destroy_op) -class PrefetchToDeviceTest(test.TestCase): +class PrefetchToDeviceTest(test_base.DatasetTestBase): def testPrefetchToDevice(self): host_dataset = dataset_ops.Dataset.range(10) @@ -445,7 +446,7 @@ class PrefetchToDeviceTest(test.TestCase): sess.run(next_element) -class CopyToDeviceTest(test.TestCase): +class CopyToDeviceTest(test_base.DatasetTestBase): def testCopyToDevice(self): host_dataset = dataset_ops.Dataset.range(10) diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index db8fe6aa1b..2e901587f4 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops import counter from tensorflow.contrib.data.python.ops import enumerate_ops +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -27,7 +28,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import test -class RangeDatasetTest(test.TestCase): +class RangeDatasetTest(test_base.DatasetTestBase): def testEnumerateDataset(self): components = (["a", "b"], [1, 2], [37.0, 38]) diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index ed75b27a44..66ed547b6d 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -25,6 +25,7 @@ import numpy as np from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.data.util import nest from tensorflow.python.framework import constant_op @@ -242,7 +243,7 @@ class ReadBatchFeaturesTest( self.assertEqual(32, shape[0]) -class MakeCsvDatasetTest(test.TestCase): +class MakeCsvDatasetTest(test_base.DatasetTestBase): def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs): return readers.make_csv_dataset( diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py index 08b9f03816..f443b5501b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py @@ -25,6 +25,7 @@ import zlib from tensorflow.contrib.data.python.ops import readers from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.framework import constant_op @@ -32,11 +33,10 @@ from tensorflow.python.framework import dtypes from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops -from tensorflow.python.platform import test from tensorflow.python.util import compat -class FixedLengthRecordDatasetTestBase(test.TestCase): +class FixedLengthRecordDatasetTestBase(test_base.DatasetTestBase): """Base class for setting up and testing FixedLengthRecordDataset.""" def setUp(self): @@ -63,7 +63,7 @@ class FixedLengthRecordDatasetTestBase(test.TestCase): return filenames -class ReadBatchFeaturesTestBase(test.TestCase): +class ReadBatchFeaturesTestBase(test_base.DatasetTestBase): """Base class for setting up and testing `make_batched_feature_dataset`.""" def setUp(self): @@ -273,7 +273,7 @@ class ReadBatchFeaturesTestBase(test.TestCase): self.assertAllEqual(expected_batch[i], actual_batch[i]) -class TextLineDatasetTestBase(test.TestCase): +class TextLineDatasetTestBase(test_base.DatasetTestBase): """Base class for setting up and testing TextLineDataset.""" def _lineText(self, f, l): @@ -313,7 +313,7 @@ class TextLineDatasetTestBase(test.TestCase): return filenames -class TFRecordDatasetTestBase(test.TestCase): +class TFRecordDatasetTestBase(test_base.DatasetTestBase): """Base class for setting up and testing TFRecordDataset.""" def setUp(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py index 16b1441baa..32474bd411 100644 --- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py @@ -24,6 +24,7 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.data.python.ops import resampling +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -57,7 +58,7 @@ def _time_resampling( return end_time - start_time -class ResampleTest(test.TestCase, parameterized.TestCase): +class ResampleTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.named_parameters( ("InitialDistributionKnown", True), diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py index dde678bd54..bdf80eae4e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py @@ -22,6 +22,7 @@ import itertools import numpy as np from tensorflow.contrib.data.python.ops import scan_ops +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.framework import constant_op @@ -33,7 +34,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class ScanDatasetTest(test.TestCase): +class ScanDatasetTest(test_base.DatasetTestBase): def _counting_dataset(self, start, scan_fn): return dataset_ops.Dataset.from_tensors(0).repeat().apply( diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index 440e48db30..c97002a255 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -20,13 +20,14 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.ops import shuffle_ops +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import test -class ShuffleAndRepeatTest(test.TestCase): +class ShuffleAndRepeatTest(test_base.DatasetTestBase): def _build_ds(self, seed, count=5, num_elements=20): return dataset_ops.Dataset.range(num_elements).apply( diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index 90d18dca2a..c5a7862322 100644 --- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -21,6 +21,7 @@ from absl.testing import parameterized import numpy as np from tensorflow.contrib.data.python.ops import sliding +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -30,7 +31,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class SlideDatasetTest(test.TestCase, parameterized.TestCase): +class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.named_parameters( ("1", 20, 14, 7, 1), @@ -197,11 +198,6 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): sliding.sliding_window_batch( window_size=1, stride=1, window_shift=1, window_stride=1)) - def assertSparseValuesEqual(self, a, b): - self.assertAllEqual(a.indices, b.indices) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.dense_shape, b.dense_shape) - def testSlideSparse(self): def _sparse(i): diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py index 1f5c725a92..319a2ea263 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py @@ -24,12 +24,13 @@ import os import sqlite3 from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class SqlDatasetTestBase(test.TestCase): +class SqlDatasetTestBase(test_base.DatasetTestBase): """Base class for setting up and testing SqlDataset.""" def _createSqlDataset(self, output_types, num_repeats=1): @@ -92,5 +93,3 @@ class SqlDatasetTestBase(test.TestCase): 9007199254740992.0)]) conn.commit() conn.close() - - diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py index b1b4c23510..80f2625927 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py @@ -19,10 +19,10 @@ from __future__ import print_function from tensorflow.core.framework import summary_pb2 -from tensorflow.python.platform import test +from tensorflow.python.data.kernel_tests import test_base -class StatsDatasetTestBase(test.TestCase): +class StatsDatasetTestBase(test_base.DatasetTestBase): """Base class for testing statistics gathered in `StatsAggregator`.""" def _assertSummaryContains(self, summary_str, tag): diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py deleted file mode 100644 index 4c3353fe40..0000000000 --- a/tensorflow/contrib/data/python/kernel_tests/test_utils.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Test utilities for tf.data functionality.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import re - -from tensorflow.python.data.util import nest -from tensorflow.python.framework import errors -from tensorflow.python.platform import test - - -class DatasetTestBase(test.TestCase): - """Base class for dataset tests.""" - - def _assert_datasets_equal(self, dataset1, dataset2): - # TODO(rachelim): support sparse tensor outputs - next1 = dataset1.make_one_shot_iterator().get_next() - next2 = dataset2.make_one_shot_iterator().get_next() - with self.cached_session() as sess: - while True: - try: - op1 = sess.run(next1) - except errors.OutOfRangeError: - with self.assertRaises(errors.OutOfRangeError): - sess.run(next2) - break - op2 = sess.run(next2) - - op1 = nest.flatten(op1) - op2 = nest.flatten(op2) - assert len(op1) == len(op2) - for i in range(len(op1)): - self.assertAllEqual(op1[i], op2[i]) - - def _assert_datasets_raise_same_error(self, - dataset1, - dataset2, - exception_class, - replacements=None): - # We are defining next1 and next2 in the same line so that we get identical - # file:line_number in the error messages - # pylint: disable=line-too-long - next1, next2 = dataset1.make_one_shot_iterator().get_next(), dataset2.make_one_shot_iterator().get_next() - # pylint: enable=line-too-long - with self.cached_session() as sess: - try: - sess.run(next1) - raise ValueError( - "Expected dataset to raise an error of type %s, but it did not." % - repr(exception_class)) - except exception_class as e: - expected_message = e.message - for old, new, count in replacements: - expected_message = expected_message.replace(old, new, count) - # Check that the first segment of the error messages are the same. - with self.assertRaisesRegexp(exception_class, - re.escape(expected_message)): - sess.run(next2) diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py index 8d335e87d5..08de3a9143 100644 --- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py @@ -24,6 +24,7 @@ import numpy as np from tensorflow.contrib.data.python.ops import threadpool from tensorflow.contrib.data.python.ops import unique +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -31,7 +32,8 @@ from tensorflow.python.ops import script_ops from tensorflow.python.platform import test -class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase): +class OverrideThreadpoolDatasetTest(test_base.DatasetTestBase, + parameterized.TestCase): @parameterized.named_parameters( ("1", 1, None), diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py index f994c8563f..8856ce5afb 100644 --- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.data.python.ops import unique +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -25,7 +26,7 @@ from tensorflow.python.platform import test from tensorflow.python.util import compat -class UniqueDatasetTest(test.TestCase): +class UniqueDatasetTest(test_base.DatasetTestBase): def _testSimpleHelper(self, dtype, test_cases): """Test the `unique()` transformation on a list of test cases. diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py index 8b7b3ac0f7..79134c7bc6 100644 --- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -31,7 +32,7 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test -class WindowDatasetTest(test.TestCase, parameterized.TestCase): +class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def _structuredDataset(self, structure, shape, dtype): if structure is None: diff --git a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py index 867ee2ba37..fca546a570 100644 --- a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import os from tensorflow.contrib.data.python.ops import writers +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers from tensorflow.python.framework import dtypes @@ -30,7 +31,7 @@ from tensorflow.python.platform import test from tensorflow.python.util import compat -class TFRecordWriterTest(test.TestCase): +class TFRecordWriterTest(test_base.DatasetTestBase): def setUp(self): super(TFRecordWriterTest, self).setUp() diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 5f9818566f..cadfe7f9e0 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -471,6 +471,9 @@ py_library( srcs = ["test_base.py"], deps = [ "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/util:nest", ], ) diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py index b4f64115b7..b730e10949 100644 --- a/tensorflow/python/data/kernel_tests/test_base.py +++ b/tensorflow/python/data/kernel_tests/test_base.py @@ -17,6 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import re + +from tensorflow.python.data.util import nest +from tensorflow.python.eager import context +from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor from tensorflow.python.platform import test @@ -24,6 +30,80 @@ class DatasetTestBase(test.TestCase): """Base class for dataset tests.""" def assertSparseValuesEqual(self, a, b): + """Asserts that two SparseTensors/SparseTensorValues are equal.""" self.assertAllEqual(a.indices, b.indices) self.assertAllEqual(a.values, b.values) self.assertAllEqual(a.dense_shape, b.dense_shape) + + def getNext(self, dataset): + """Returns a callable that returns the next element of the dataset. + + Example use: + ```python + # In both graph and eager modes + dataset = ... + nxt = self.getNext(dataset) + result = self.evaluate(nxt()) + ``` + + Args: + dataset: A dataset whose next element is returned + + Returns: + A callable that returns the next element of `dataset` + """ + it = dataset.make_one_shot_iterator() + if context.executing_eagerly(): + return it.get_next + else: + nxt = it.get_next() + return lambda: nxt + + def assertDatasetsEqual(self, dataset1, dataset2): + """Checks that datasets are equal. Supports both graph and eager mode.""" + self.assertEqual(dataset1.output_types, dataset2.output_types) + self.assertEqual(dataset1.output_classes, dataset2.output_classes) + + next1 = self.getNext(dataset1) + next2 = self.getNext(dataset2) + while True: + try: + op1 = self.evaluate(next1()) + except errors.OutOfRangeError: + with self.assertRaises(errors.OutOfRangeError): + self.evaluate(next2()) + break + op2 = self.evaluate(next2()) + + op1 = nest.flatten(op1) + op2 = nest.flatten(op2) + assert len(op1) == len(op2) + for i in range(len(op1)): + if isinstance( + op1[i], + (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)): + self.assertSparseValuesEqual(op1[i], op2[i]) + else: + self.assertAllEqual(op1[i], op2[i]) + + def assertDatasetsRaiseSameError(self, + dataset1, + dataset2, + exception_class, + replacements=None): + """Checks that datasets raise the same error on the first get_next call.""" + next1 = self.getNext(dataset1) + next2 = self.getNext(dataset2) + try: + self.evaluate(next1()) + raise ValueError( + 'Expected dataset to raise an error of type %s, but it did not.' % + repr(exception_class)) + except exception_class as e: + expected_message = e.message + for old, new, count in replacements: + expected_message = expected_message.replace(old, new, count) + # Check that the first segment of the error messages are the same. + with self.assertRaisesRegexp(exception_class, + re.escape(expected_message)): + self.evaluate(next2()) diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 7d925a8fef..c621812535 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -66,7 +66,6 @@ COMMON_PIP_DEPS = [ "//tensorflow/contrib/constrained_optimization:constrained_optimization_pip", "//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base", "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base", - "//tensorflow/contrib/data/python/kernel_tests:test_utils", "//tensorflow/contrib/eager/python/examples:examples_pip", "//tensorflow/contrib/eager/python:evaluator", "//tensorflow/contrib/gan:gan", -- cgit v1.2.3