diff options
Diffstat (limited to 'tensorflow/python/data/kernel_tests/test_base.py')
-rw-r--r-- | tensorflow/python/data/kernel_tests/test_base.py | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py index b730e10949..b73a94e683 100644 --- a/tensorflow/python/data/kernel_tests/test_base.py +++ b/tensorflow/python/data/kernel_tests/test_base.py @@ -19,10 +19,13 @@ from __future__ import print_function import re +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -107,3 +110,29 @@ class DatasetTestBase(test.TestCase): with self.assertRaisesRegexp(exception_class, re.escape(expected_message)): self.evaluate(next2()) + + def structuredDataset(self, structure, shape=None, dtype=dtypes.int64): + """Returns a singleton dataset with the given structure.""" + if shape is None: + shape = [] + if structure is None: + return dataset_ops.Dataset.from_tensors( + array_ops.zeros(shape, dtype=dtype)) + else: + return dataset_ops.Dataset.zip( + tuple([ + self.structuredDataset(substructure, shape, dtype) + for substructure in structure + ])) + + def structuredElement(self, structure, shape=None, dtype=dtypes.int64): + """Returns an element with the given structure.""" + if shape is None: + shape = [] + if structure is None: + return array_ops.zeros(shape, dtype=dtype) + else: + return tuple([ + self.structuredElement(substructure, shape, dtype) + for substructure in structure + ]) |