aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data/kernel_tests/test_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/data/kernel_tests/test_base.py')
-rw-r--r--tensorflow/python/data/kernel_tests/test_base.py29
1 files changed, 29 insertions, 0 deletions
diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py
index b730e10949..b73a94e683 100644
--- a/tensorflow/python/data/kernel_tests/test_base.py
+++ b/tensorflow/python/data/kernel_tests/test_base.py
@@ -19,10 +19,13 @@ from __future__ import print_function
import re
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -107,3 +110,29 @@ class DatasetTestBase(test.TestCase):
with self.assertRaisesRegexp(exception_class,
re.escape(expected_message)):
self.evaluate(next2())
+
+ def structuredDataset(self, structure, shape=None, dtype=dtypes.int64):
+ """Returns a singleton dataset with the given structure."""
+ if shape is None:
+ shape = []
+ if structure is None:
+ return dataset_ops.Dataset.from_tensors(
+ array_ops.zeros(shape, dtype=dtype))
+ else:
+ return dataset_ops.Dataset.zip(
+ tuple([
+ self.structuredDataset(substructure, shape, dtype)
+ for substructure in structure
+ ]))
+
+ def structuredElement(self, structure, shape=None, dtype=dtypes.int64):
+ """Returns an element with the given structure."""
+ if shape is None:
+ shape = []
+ if structure is None:
+ return array_ops.zeros(shape, dtype=dtype)
+ else:
+ return tuple([
+ self.structuredElement(substructure, shape, dtype)
+ for substructure in structure
+ ])