aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-05-16 21:15:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-16 21:17:59 -0700
commit147a31e3fc2505e467cdd781019f20e0d6aa1a58 (patch)
tree9b718b681c7d12c38c5709f6efea740a3fccca54
parentdeca317a9c8b4567cccc3270fc63065dbbe23c69 (diff)
[tf.data] Accept NumPy dtype objects in `Dataset.from_generator(..., output_types=...)`.
PiperOrigin-RevId: 196935179
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py11
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py2
2 files changed, 8 insertions, 5 deletions
diff --git a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
index 9fcdf1b062..296a76ec88 100644
--- a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
@@ -32,9 +32,12 @@ from tensorflow.python.platform import test
class DatasetConstructorTest(test.TestCase):
- def _testFromGenerator(self, generator, elem_sequence, num_repeats):
+ def _testFromGenerator(self, generator, elem_sequence, num_repeats,
+ output_types=None):
+ if output_types is None:
+ output_types = dtypes.int64
iterator = (
- dataset_ops.Dataset.from_generator(generator, output_types=dtypes.int64)
+ dataset_ops.Dataset.from_generator(generator, output_types=output_types)
.repeat(num_repeats)
.prefetch(5)
.make_initializable_iterator())
@@ -84,8 +87,8 @@ class DatasetConstructorTest(test.TestCase):
def testFromGeneratorUsingNdarray(self):
generator = lambda: np.arange(100, dtype=np.int64)
elem_sequence = list(generator())
- self._testFromGenerator(generator, elem_sequence, 1)
- self._testFromGenerator(generator, elem_sequence, 5)
+ self._testFromGenerator(generator, elem_sequence, 1, output_types=np.int64)
+ self._testFromGenerator(generator, elem_sequence, 5, output_types=np.int64)
def testFromGeneratorUsingGeneratorExpression(self):
# NOTE(mrry): Generator *expressions* are not repeatable (or in
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 8b3c2facbc..6a3f6bf40c 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -354,7 +354,7 @@ class Dataset(object):
else:
args = tuple(ops.convert_n_to_tensor(args, name="args"))
- flattened_types = nest.flatten(output_types)
+ flattened_types = [dtypes.as_dtype(dt) for dt in nest.flatten(output_types)]
flattened_shapes = nest.flatten(output_shapes)
generator_state = Dataset._GeneratorState(generator)