aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Rachel Lim <rachelim@google.com>2018-09-28 15:22:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 15:25:06 -0700
commit1c4a48ddd49f78fbd8ea3defd3a8755c91284166 (patch)
tree32c6e79a58dd3b6dc62c4eb225651c6fb003d16a /tensorflow
parent2f559f2d5f75cf80183ae0d855110809404019f7 (diff)
[tf.data] Merged contrib.data's DatasetTestBase with the DatasetTestBase in core (and added that as a base class for all the contrib tests). Also changed the assertDatasetsEqual functions so they are both graph and eager compatible (took the code from CSVDatasetTest) :)
PiperOrigin-RevId: 215004892
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD37
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py9
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/bucketing_test.py9
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py43
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/BUILD9
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py14
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py7
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py10
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/resample_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py8
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py5
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/test_utils.py73
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD3
-rw-r--r--tensorflow/python/data/kernel_tests/test_base.py80
-rw-r--r--tensorflow/tools/pip_package/BUILD1
41 files changed, 209 insertions, 183 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 21ac40eb21..33784afa3f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -31,6 +31,7 @@ py_test(
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -54,6 +55,7 @@ py_test(
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -77,6 +79,7 @@ py_test(
"//tensorflow/python:platform",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:readers",
"//tensorflow/python/eager:context",
"//third_party/py/numpy",
@@ -97,6 +100,7 @@ py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
],
@@ -112,6 +116,7 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:random_seed",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -130,6 +135,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
@@ -144,6 +150,7 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -169,6 +176,7 @@ py_test(
"//tensorflow/python:script_ops",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@six_archive//:six",
],
@@ -188,6 +196,7 @@ py_test(
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
"//tensorflow/python:variables",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/estimator:estimator_py",
],
@@ -214,6 +223,7 @@ py_test(
"//tensorflow/python:platform",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//third_party/py/numpy",
],
)
@@ -239,6 +249,7 @@ py_test(
"//tensorflow/python:io_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -258,6 +269,7 @@ py_test(
"//tensorflow/python:io_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -282,6 +294,7 @@ py_test(
"//tensorflow/python:functional_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:session",
+ "//tensorflow/python/data/kernel_tests:test_base",
],
)
@@ -300,6 +313,7 @@ py_test(
"//tensorflow/python:parsing_ops",
"//tensorflow/python:platform",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
"//third_party/py/numpy",
@@ -315,6 +329,7 @@ cuda_py_test(
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
@@ -340,6 +355,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -365,6 +381,7 @@ py_library(
"//tensorflow/python:lib",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/ops:readers",
],
@@ -411,6 +428,7 @@ py_test(
"//tensorflow/python:random_ops",
"//tensorflow/python:string_ops",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -433,6 +451,7 @@ py_test(
"//tensorflow/python:errors",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/eager:context",
"//third_party/py/numpy",
@@ -453,6 +472,7 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -470,6 +490,7 @@ py_test(
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -489,6 +510,7 @@ py_library(
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
+ "//tensorflow/python/data/kernel_tests:test_base",
"@org_sqlite//:python",
],
)
@@ -533,6 +555,7 @@ py_library(
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/kernel_tests:test_base",
],
)
@@ -549,6 +572,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:script_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -567,6 +591,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -587,6 +612,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:math_ops",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -604,17 +630,8 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:lib",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:readers",
],
)
-
-py_library(
- name = "test_utils",
- srcs = ["test_utils.py"],
- deps = [
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/util:nest",
- ],
-)
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index e2508de9e9..fed7de5f2b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import batching
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -40,12 +41,8 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class BatchDatasetTest(test.TestCase, parameterized.TestCase):
+class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
def testDenseToSparseBatchDataset(self):
components = np.random.randint(12, size=(100,)).astype(np.int32)
@@ -723,7 +720,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
-class RestructuredDatasetTest(test.TestCase):
+class RestructuredDatasetTest(test_base.DatasetTestBase):
def test_assert_element_shape(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
index 48971f2ccc..ae401f786c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
@@ -22,6 +22,7 @@ import random
import numpy as np
from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -35,7 +36,7 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
-class GroupByReducerTest(test.TestCase):
+class GroupByReducerTest(test_base.DatasetTestBase):
def checkResults(self, dataset, shapes, values):
self.assertEqual(shapes, dataset.output_shapes)
@@ -198,7 +199,7 @@ class GroupByReducerTest(test.TestCase):
self.assertEqual(y, 45)
-class GroupByWindowTest(test.TestCase):
+class GroupByWindowTest(test_base.DatasetTestBase):
def testSimple(self):
components = np.random.randint(100, size=(200,)).astype(np.int64)
@@ -345,7 +346,7 @@ class GroupByWindowTest(test.TestCase):
# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
# Currently, they use a constant batch size, though should be made to use a
# different batch size per key.
-class BucketTest(test.TestCase):
+class BucketTest(test_base.DatasetTestBase):
def _dynamicPad(self, bucket, window, window_size):
# TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
@@ -570,7 +571,7 @@ def _get_record_shape(sparse):
return tensor_shape.TensorShape([None])
-class BucketBySequenceLength(test.TestCase):
+class BucketBySequenceLength(test_base.DatasetTestBase):
def testBucket(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
index f8e74e4583..5b3c512b64 100644
--- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
@@ -30,6 +30,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import error_ops
from tensorflow.contrib.data.python.ops import readers
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -43,37 +44,7 @@ from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
-class CsvDatasetOpTest(test.TestCase):
-
- def _get_next(self, dataset):
- # Returns a no argument function whose result is fed to self.evaluate to
- # yield the next element
- it = dataset.make_one_shot_iterator()
- if context.executing_eagerly():
- return it.get_next
- else:
- get_next = it.get_next()
- return lambda: get_next
-
- def _assert_datasets_equal(self, ds1, ds2):
- assert ds1.output_shapes == ds2.output_shapes, ('output_shapes differ: %s, '
- '%s') % (ds1.output_shapes,
- ds2.output_shapes)
- assert ds1.output_types == ds2.output_types
- assert ds1.output_classes == ds2.output_classes
- next1 = self._get_next(ds1)
- next2 = self._get_next(ds2)
- # Run through datasets and check that outputs match, or errors match.
- while True:
- try:
- op1 = self.evaluate(next1())
- except (errors.OutOfRangeError, ValueError) as e:
- # If op1 throws an exception, check that op2 throws same exception.
- with self.assertRaises(type(e)):
- self.evaluate(next2())
- break
- op2 = self.evaluate(next2())
- self.assertAllEqual(op1, op2)
+class CsvDatasetOpTest(test_base.DatasetTestBase):
def _setup_files(self, inputs, linebreak='\n', compression_type=None):
filenames = []
@@ -108,7 +79,7 @@ class CsvDatasetOpTest(test.TestCase):
"""Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv)."""
dataset_actual, dataset_expected = self._make_test_datasets(
inputs, **kwargs)
- self._assert_datasets_equal(dataset_actual, dataset_expected)
+ self.assertDatasetsEqual(dataset_actual, dataset_expected)
def _verify_output_or_err(self,
dataset,
@@ -116,7 +87,7 @@ class CsvDatasetOpTest(test.TestCase):
expected_err_re=None):
if expected_err_re is None:
# Verify that output is expected, without errors
- nxt = self._get_next(dataset)
+ nxt = self.getNext(dataset)
expected_output = [[
v.encode('utf-8') if isinstance(v, str) else v for v in op
] for op in expected_output]
@@ -128,7 +99,7 @@ class CsvDatasetOpTest(test.TestCase):
else:
# Verify that OpError is produced as expected
with self.assertRaisesOpError(expected_err_re):
- nxt = self._get_next(dataset)
+ nxt = self.getNext(dataset)
while True:
try:
self.evaluate(nxt())
@@ -354,7 +325,7 @@ class CsvDatasetOpTest(test.TestCase):
inputs = [['1,,3,4', '5,6,,8']]
ds_actual, ds_expected = self._make_test_datasets(
inputs, record_defaults=record_defaults)
- self._assert_datasets_equal(
+ self.assertDatasetsEqual(
ds_actual.repeat(5).prefetch(1),
ds_expected.repeat(5).prefetch(1))
@@ -377,7 +348,7 @@ class CsvDatasetOpTest(test.TestCase):
ds = readers.make_csv_dataset(
file_path, batch_size=1, shuffle=False, num_epochs=1)
- nxt = self._get_next(ds)
+ nxt = self.getNext(ds)
result = list(self.evaluate(nxt()).values())
diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
index a2ab3de52e..722e87e555 100644
--- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -25,7 +26,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class DatasetConstructorTest(test.TestCase):
+class DatasetConstructorTest(test_base.DatasetTestBase):
def testRestructureDataset(self):
components = (array_ops.placeholder(dtypes.int32),
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
index eb110324d1..bc10c21472 100644
--- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
@@ -20,13 +20,14 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import random_seed
from tensorflow.python.platform import test
-class DirectedInterleaveDatasetTest(test.TestCase):
+class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
def testBasic(self):
selector_dataset = dataset_ops.Dataset.range(10).repeat(100)
diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
index f3968cdc15..cc22ea1df7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import get_single_element
from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -30,7 +31,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class GetSingleElementTest(test.TestCase, parameterized.TestCase):
+class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("Zero", 0, 1),
diff --git a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
index 46a7127b52..d4d3d4adb2 100644
--- a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import unittest
from tensorflow.contrib.data.python.ops import indexed_dataset_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -28,7 +29,7 @@ from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.platform import test
-class IndexedDatasetOpsTest(test.TestCase):
+class IndexedDatasetOpsTest(test_base.DatasetTestBase):
def testLowLevelIndexedDatasetOps(self):
identity = ged_ops.experimental_identity_indexed_dataset(
diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
index b9e74dfddb..28bd670ab5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
@@ -25,6 +25,7 @@ import time
from six.moves import zip_longest
from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -36,7 +37,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
-class ParallelInterleaveDatasetTest(test.TestCase):
+class ParallelInterleaveDatasetTest(test_base.DatasetTestBase):
def setUp(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
index 7e2326bd17..58a1d7c93b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import iterator_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn
@@ -33,7 +34,7 @@ from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
-class CheckpointInputPipelineHookTest(test.TestCase):
+class CheckpointInputPipelineHookTest(test_base.DatasetTestBase):
@staticmethod
def _model_fn(features, labels, mode, config):
diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
index 1cc5ddc9a2..d2a72272db 100644
--- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
@@ -22,6 +22,7 @@ import os
import shutil
from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -31,7 +32,7 @@ from tensorflow.python.util import compat
prefix_path = "tensorflow/core/lib"
-class LMDBDatasetTest(test.TestCase):
+class LMDBDatasetTest(test_base.DatasetTestBase):
def setUp(self):
super(LMDBDatasetTest, self).setUp()
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
index e8519381d6..385c4ef6ea 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
@@ -29,6 +29,7 @@ from tensorflow.contrib.data.python.ops import error_ops
from tensorflow.contrib.data.python.ops import optimization
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -41,7 +42,7 @@ from tensorflow.python.util import compat
_NUMPY_RANDOM_SEED = 42
-class MapDatasetTest(test.TestCase):
+class MapDatasetTest(test_base.DatasetTestBase):
def testMapIgnoreError(self):
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
index 25aea0393f..751e6d5b30 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
@@ -21,6 +21,7 @@ import time
from tensorflow.contrib.data.python.ops import map_defun
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -33,7 +34,8 @@ from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class MapDefunTest(test.TestCase):
+
+class MapDefunTest(test_base.DatasetTestBase):
def testMapDefunSimple(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
index 1ae92bdeff..d7b5edcd9a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
@@ -15,6 +15,7 @@ py_test(
"//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -31,6 +32,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
@@ -57,7 +59,6 @@ py_test(
srcs = ["map_vectorization_test.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/data/python/kernel_tests:test_utils",
"//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:check_ops",
"//tensorflow/python:client_testlib",
@@ -67,6 +68,7 @@ py_test(
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:session",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -85,6 +87,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
@@ -102,6 +105,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
@@ -121,6 +125,7 @@ py_test(
"//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -137,6 +142,7 @@ py_test(
"//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -151,6 +157,7 @@ py_test(
"//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
index d10da80442..fe1b5280ba 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
@@ -18,12 +18,13 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
-class AssertNextDatasetTest(test.TestCase):
+class AssertNextDatasetTest(test_base.DatasetTestBase):
def testAssertNext(self):
dataset = dataset_ops.Dataset.from_tensors(0).apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py
index 9518c2e1ad..b43efb5c7c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -31,7 +32,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
-class HoistRandomUniformTest(test.TestCase, parameterized.TestCase):
+class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
@staticmethod
def map_functions():
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
index e75edf6086..e9e3fc81e5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -28,7 +29,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
+class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
@staticmethod
def map_functions():
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
index dd547db086..f7907eb890 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -30,7 +31,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
-class MapParallelizationTest(test.TestCase, parameterized.TestCase):
+class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
@staticmethod
def map_functions():
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
index 5b493f44c9..a5ea85f454 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
@@ -22,9 +22,9 @@ import time
from absl.testing import parameterized
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import test_utils
from tensorflow.contrib.data.python.ops import optimization
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -36,7 +36,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
+class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
def _get_test_datasets(self,
base_dataset,
@@ -85,7 +85,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
[3, 4]]).repeat(5)
unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn,
num_parallel_calls)
- self._assert_datasets_equal(unoptimized, optimized)
+ self.assertDatasetsEqual(unoptimized, optimized)
def testOptimizationBadMapFn(self):
# Test map functions that give an error
@@ -112,7 +112,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
# TODO(rachelim): when this optimization works, turn on expect_optimized
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_equal(optimized, unoptimized)
+ self.assertDatasetsEqual(optimized, unoptimized)
def testOptimizationIgnoreStateful(self):
@@ -124,7 +124,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
[3, 4]]).repeat(5)
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_raise_same_error(
+ self.assertDatasetsRaiseSameError(
unoptimized, optimized, errors.InvalidArgumentError,
[("OneShotIterator", "OneShotIterator_1", 1),
("IteratorGetNext", "IteratorGetNext_1", 1)])
@@ -138,7 +138,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
base_dataset = dataset_ops.Dataset.range(20).batch(3, drop_remainder=False)
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_equal(unoptimized, optimized)
+ self.assertDatasetsEqual(unoptimized, optimized)
def testOptimizationIgnoreRaggedMap(self):
# Don't optimize when the output of the map fn shapes are unknown.
@@ -148,7 +148,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True)
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_raise_same_error(
+ self.assertDatasetsRaiseSameError(
unoptimized, optimized, errors.InvalidArgumentError,
[("OneShotIterator", "OneShotIterator_1", 1),
("IteratorGetNext", "IteratorGetNext_1", 1)])
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
index 3b62a7e468..33c250ab2a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
@@ -23,12 +23,13 @@ import numpy as np
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class ModelDatasetTest(test.TestCase):
+class ModelDatasetTest(test_base.DatasetTestBase):
def testModelMap(self):
k = 1024 * 1024
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
index 507feda3ad..b9e60cfa4e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -26,7 +27,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class NoopEliminationTest(test.TestCase):
+class NoopEliminationTest(test_base.DatasetTestBase):
def testNoopElimination(self):
a = constant_op.constant(1, dtype=dtypes.int64)
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
index a3fb824ce9..04f499f8c5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -28,7 +29,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
-class OptimizeDatasetTest(test.TestCase):
+class OptimizeDatasetTest(test_base.DatasetTestBase):
def testOptimizationDefault(self):
dataset = dataset_ops.Dataset.range(10).apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
index c4623bca73..66ccaceea5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import parsing_ops as contrib_parsing_ops
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -72,7 +73,7 @@ def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
i += 1
-class ParseExampleTest(test.TestCase):
+class ParseExampleTest(test_base.DatasetTestBase):
def _test(self,
input_tensor,
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index 33a64ea767..7a6a7a709a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -22,6 +22,7 @@ import threading
from tensorflow.contrib.data.python.ops import prefetching_ops
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compat import compat
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@@ -35,7 +36,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
-class PrefetchingKernelsOpsTest(test.TestCase):
+class PrefetchingKernelsOpsTest(test_base.DatasetTestBase):
def setUp(self):
self._event = threading.Event()
@@ -244,7 +245,7 @@ class PrefetchingKernelsOpsTest(test.TestCase):
sess.run(destroy_op)
-class PrefetchToDeviceTest(test.TestCase):
+class PrefetchToDeviceTest(test_base.DatasetTestBase):
def testPrefetchToDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
@@ -445,7 +446,7 @@ class PrefetchToDeviceTest(test.TestCase):
sess.run(next_element)
-class CopyToDeviceTest(test.TestCase):
+class CopyToDeviceTest(test_base.DatasetTestBase):
def testCopyToDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
index db8fe6aa1b..2e901587f4 100644
--- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
from tensorflow.contrib.data.python.ops import counter
from tensorflow.contrib.data.python.ops import enumerate_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -27,7 +28,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
-class RangeDatasetTest(test.TestCase):
+class RangeDatasetTest(test_base.DatasetTestBase):
def testEnumerateDataset(self):
components = (["a", "b"], [1, 2], [37.0, 38])
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index ed75b27a44..66ed547b6d 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.data.util import nest
from tensorflow.python.framework import constant_op
@@ -242,7 +243,7 @@ class ReadBatchFeaturesTest(
self.assertEqual(32, shape[0])
-class MakeCsvDatasetTest(test.TestCase):
+class MakeCsvDatasetTest(test_base.DatasetTestBase):
def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs):
return readers.make_csv_dataset(
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
index 08b9f03816..f443b5501b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
@@ -25,6 +25,7 @@ import zlib
from tensorflow.contrib.data.python.ops import readers
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.framework import constant_op
@@ -32,11 +33,10 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import parsing_ops
-from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class FixedLengthRecordDatasetTestBase(test.TestCase):
+class FixedLengthRecordDatasetTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing FixedLengthRecordDataset."""
def setUp(self):
@@ -63,7 +63,7 @@ class FixedLengthRecordDatasetTestBase(test.TestCase):
return filenames
-class ReadBatchFeaturesTestBase(test.TestCase):
+class ReadBatchFeaturesTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing `make_batched_feature_dataset`."""
def setUp(self):
@@ -273,7 +273,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
self.assertAllEqual(expected_batch[i], actual_batch[i])
-class TextLineDatasetTestBase(test.TestCase):
+class TextLineDatasetTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing TextLineDataset."""
def _lineText(self, f, l):
@@ -313,7 +313,7 @@ class TextLineDatasetTestBase(test.TestCase):
return filenames
-class TFRecordDatasetTestBase(test.TestCase):
+class TFRecordDatasetTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing TFRecordDataset."""
def setUp(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
index 16b1441baa..32474bd411 100644
--- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
@@ -24,6 +24,7 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.data.python.ops import resampling
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -57,7 +58,7 @@ def _time_resampling(
return end_time - start_time
-class ResampleTest(test.TestCase, parameterized.TestCase):
+class ResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("InitialDistributionKnown", True),
diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
index dde678bd54..bdf80eae4e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
@@ -22,6 +22,7 @@ import itertools
import numpy as np
from tensorflow.contrib.data.python.ops import scan_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -33,7 +34,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ScanDatasetTest(test.TestCase):
+class ScanDatasetTest(test_base.DatasetTestBase):
def _counting_dataset(self, start, scan_fn):
return dataset_ops.Dataset.from_tensors(0).repeat().apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
index 440e48db30..c97002a255 100644
--- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
@@ -20,13 +20,14 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import shuffle_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
-class ShuffleAndRepeatTest(test.TestCase):
+class ShuffleAndRepeatTest(test_base.DatasetTestBase):
def _build_ds(self, seed, count=5, num_elements=20):
return dataset_ops.Dataset.range(num_elements).apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
index 90d18dca2a..c5a7862322 100644
--- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
@@ -21,6 +21,7 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.data.python.ops import sliding
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -30,7 +31,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class SlideDatasetTest(test.TestCase, parameterized.TestCase):
+class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("1", 20, 14, 7, 1),
@@ -197,11 +198,6 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
sliding.sliding_window_batch(
window_size=1, stride=1, window_shift=1, window_stride=1))
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testSlideSparse(self):
def _sparse(i):
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py
index 1f5c725a92..319a2ea263 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py
@@ -24,12 +24,13 @@ import os
import sqlite3
from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class SqlDatasetTestBase(test.TestCase):
+class SqlDatasetTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing SqlDataset."""
def _createSqlDataset(self, output_types, num_repeats=1):
@@ -92,5 +93,3 @@ class SqlDatasetTestBase(test.TestCase):
9007199254740992.0)])
conn.commit()
conn.close()
-
-
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
index b1b4c23510..80f2625927 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
@@ -19,10 +19,10 @@ from __future__ import print_function
from tensorflow.core.framework import summary_pb2
-from tensorflow.python.platform import test
+from tensorflow.python.data.kernel_tests import test_base
-class StatsDatasetTestBase(test.TestCase):
+class StatsDatasetTestBase(test_base.DatasetTestBase):
"""Base class for testing statistics gathered in `StatsAggregator`."""
def _assertSummaryContains(self, summary_str, tag):
diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
deleted file mode 100644
index 4c3353fe40..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/test_utils.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Test utilities for tf.data functionality."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import re
-
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import errors
-from tensorflow.python.platform import test
-
-
-class DatasetTestBase(test.TestCase):
- """Base class for dataset tests."""
-
- def _assert_datasets_equal(self, dataset1, dataset2):
- # TODO(rachelim): support sparse tensor outputs
- next1 = dataset1.make_one_shot_iterator().get_next()
- next2 = dataset2.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- while True:
- try:
- op1 = sess.run(next1)
- except errors.OutOfRangeError:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next2)
- break
- op2 = sess.run(next2)
-
- op1 = nest.flatten(op1)
- op2 = nest.flatten(op2)
- assert len(op1) == len(op2)
- for i in range(len(op1)):
- self.assertAllEqual(op1[i], op2[i])
-
- def _assert_datasets_raise_same_error(self,
- dataset1,
- dataset2,
- exception_class,
- replacements=None):
- # We are defining next1 and next2 in the same line so that we get identical
- # file:line_number in the error messages
- # pylint: disable=line-too-long
- next1, next2 = dataset1.make_one_shot_iterator().get_next(), dataset2.make_one_shot_iterator().get_next()
- # pylint: enable=line-too-long
- with self.cached_session() as sess:
- try:
- sess.run(next1)
- raise ValueError(
- "Expected dataset to raise an error of type %s, but it did not." %
- repr(exception_class))
- except exception_class as e:
- expected_message = e.message
- for old, new, count in replacements:
- expected_message = expected_message.replace(old, new, count)
- # Check that the first segment of the error messages are the same.
- with self.assertRaisesRegexp(exception_class,
- re.escape(expected_message)):
- sess.run(next2)
diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
index 8d335e87d5..08de3a9143 100644
--- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
@@ -24,6 +24,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import threadpool
from tensorflow.contrib.data.python.ops import unique
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -31,7 +32,8 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
-class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase):
+class OverrideThreadpoolDatasetTest(test_base.DatasetTestBase,
+ parameterized.TestCase):
@parameterized.named_parameters(
("1", 1, None),
diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
index f994c8563f..8856ce5afb 100644
--- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import unique
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -25,7 +26,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class UniqueDatasetTest(test.TestCase):
+class UniqueDatasetTest(test_base.DatasetTestBase):
def _testSimpleHelper(self, dtype, test_cases):
"""Test the `unique()` transformation on a list of test cases.
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
index 8b7b3ac0f7..79134c7bc6 100644
--- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -31,7 +32,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
-class WindowDatasetTest(test.TestCase, parameterized.TestCase):
+class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def _structuredDataset(self, structure, shape, dtype):
if structure is None:
diff --git a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
index 867ee2ba37..fca546a570 100644
--- a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import os
from tensorflow.contrib.data.python.ops import writers
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.framework import dtypes
@@ -30,7 +31,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class TFRecordWriterTest(test.TestCase):
+class TFRecordWriterTest(test_base.DatasetTestBase):
def setUp(self):
super(TFRecordWriterTest, self).setUp()
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 5f9818566f..cadfe7f9e0 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -471,6 +471,9 @@ py_library(
srcs = ["test_base.py"],
deps = [
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/util:nest",
],
)
diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py
index b4f64115b7..b730e10949 100644
--- a/tensorflow/python/data/kernel_tests/test_base.py
+++ b/tensorflow/python/data/kernel_tests/test_base.py
@@ -17,6 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import re
+
+from tensorflow.python.data.util import nest
+from tensorflow.python.eager import context
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.platform import test
@@ -24,6 +30,80 @@ class DatasetTestBase(test.TestCase):
"""Base class for dataset tests."""
def assertSparseValuesEqual(self, a, b):
+ """Asserts that two SparseTensors/SparseTensorValues are equal."""
self.assertAllEqual(a.indices, b.indices)
self.assertAllEqual(a.values, b.values)
self.assertAllEqual(a.dense_shape, b.dense_shape)
+
+ def getNext(self, dataset):
+ """Returns a callable that returns the next element of the dataset.
+
+ Example use:
+ ```python
+ # In both graph and eager modes
+ dataset = ...
+ nxt = self.getNext(dataset)
+ result = self.evaluate(nxt())
+ ```
+
+ Args:
+ dataset: A dataset whose next element is returned
+
+ Returns:
+ A callable that returns the next element of `dataset`
+ """
+ it = dataset.make_one_shot_iterator()
+ if context.executing_eagerly():
+ return it.get_next
+ else:
+ nxt = it.get_next()
+ return lambda: nxt
+
+ def assertDatasetsEqual(self, dataset1, dataset2):
+ """Checks that datasets are equal. Supports both graph and eager mode."""
+ self.assertEqual(dataset1.output_types, dataset2.output_types)
+ self.assertEqual(dataset1.output_classes, dataset2.output_classes)
+
+ next1 = self.getNext(dataset1)
+ next2 = self.getNext(dataset2)
+ while True:
+ try:
+ op1 = self.evaluate(next1())
+ except errors.OutOfRangeError:
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(next2())
+ break
+ op2 = self.evaluate(next2())
+
+ op1 = nest.flatten(op1)
+ op2 = nest.flatten(op2)
+ assert len(op1) == len(op2)
+ for i in range(len(op1)):
+ if isinstance(
+ op1[i],
+ (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
+ self.assertSparseValuesEqual(op1[i], op2[i])
+ else:
+ self.assertAllEqual(op1[i], op2[i])
+
+ def assertDatasetsRaiseSameError(self,
+ dataset1,
+ dataset2,
+ exception_class,
+ replacements=None):
+ """Checks that datasets raise the same error on the first get_next call."""
+ next1 = self.getNext(dataset1)
+ next2 = self.getNext(dataset2)
+ try:
+ self.evaluate(next1())
+ raise ValueError(
+ 'Expected dataset to raise an error of type %s, but it did not.' %
+ repr(exception_class))
+ except exception_class as e:
+ expected_message = e.message
+ for old, new, count in replacements:
+ expected_message = expected_message.replace(old, new, count)
+ # Check that the first segment of the error messages are the same.
+ with self.assertRaisesRegexp(exception_class,
+ re.escape(expected_message)):
+ self.evaluate(next2())
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 7d925a8fef..c621812535 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -66,7 +66,6 @@ COMMON_PIP_DEPS = [
"//tensorflow/contrib/constrained_optimization:constrained_optimization_pip",
"//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base",
"//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:test_utils",
"//tensorflow/contrib/eager/python/examples:examples_pip",
"//tensorflow/contrib/eager/python:evaluator",
"//tensorflow/contrib/gan:gan",