aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu/python/tpu/datasets_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tpu/python/tpu/datasets_test.py')
-rw-r--r--tensorflow/contrib/tpu/python/tpu/datasets_test.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/datasets_test.py b/tensorflow/contrib/tpu/python/tpu/datasets_test.py
index 918cf0ed8e..b58d05eac5 100644
--- a/tensorflow/contrib/tpu/python/tpu/datasets_test.py
+++ b/tensorflow/contrib/tpu/python/tpu/datasets_test.py
@@ -26,6 +26,8 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
@@ -162,6 +164,30 @@ class DatasetsTest(test.TestCase):
self.assertEqual(set(all_contents), set(retrieved_values))
+ def testArbitraryReaderFuncFromDatasetGenerator(self):
+
+ def my_generator():
+ yield (1, [1] * 10)
+
+ def gen_dataset(dummy):
+ return dataset_ops.Dataset.from_generator(
+ my_generator, (dtypes.int64, dtypes.int64),
+ (tensor_shape.TensorShape([]), tensor_shape.TensorShape([10])))
+
+ dataset = datasets.StreamingFilesDataset(
+ dataset_ops.Dataset.range(10), filetype=gen_dataset)
+
+ iterator = dataset.make_initializable_iterator()
+ self._sess.run(iterator.initializer)
+ get_next = iterator.get_next()
+
+ retrieved_values = self._sess.run(get_next)
+
+ self.assertIsInstance(retrieved_values, (list, tuple))
+ self.assertEqual(len(retrieved_values), 2)
+ self.assertEqual(retrieved_values[0], 1)
+ self.assertItemsEqual(retrieved_values[1], [1] * 10)
+
def testUnexpectedFiletypeString(self):
with self.assertRaises(ValueError):
datasets.StreamingFilesDataset(