diff options
author | 2018-05-16 21:15:44 -0700 | |
---|---|---|
committer | 2018-05-16 21:17:59 -0700 | |
commit | 147a31e3fc2505e467cdd781019f20e0d6aa1a58 (patch) | |
tree | 9b718b681c7d12c38c5709f6efea740a3fccca54 | |
parent | deca317a9c8b4567cccc3270fc63065dbbe23c69 (diff) |
[tf.data] Accept NumPy dtype objects in `Dataset.from_generator(..., output_types=...)`.
PiperOrigin-RevId: 196935179
-rw-r--r-- | tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py | 11 | ||||
-rw-r--r-- | tensorflow/python/data/ops/dataset_ops.py | 2 |
2 files changed, 8 insertions, 5 deletions
diff --git a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py index 9fcdf1b062..296a76ec88 100644 --- a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py +++ b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py @@ -32,9 +32,12 @@ from tensorflow.python.platform import test class DatasetConstructorTest(test.TestCase): - def _testFromGenerator(self, generator, elem_sequence, num_repeats): + def _testFromGenerator(self, generator, elem_sequence, num_repeats, + output_types=None): + if output_types is None: + output_types = dtypes.int64 iterator = ( - dataset_ops.Dataset.from_generator(generator, output_types=dtypes.int64) + dataset_ops.Dataset.from_generator(generator, output_types=output_types) .repeat(num_repeats) .prefetch(5) .make_initializable_iterator()) @@ -84,8 +87,8 @@ class DatasetConstructorTest(test.TestCase): def testFromGeneratorUsingNdarray(self): generator = lambda: np.arange(100, dtype=np.int64) elem_sequence = list(generator()) - self._testFromGenerator(generator, elem_sequence, 1) - self._testFromGenerator(generator, elem_sequence, 5) + self._testFromGenerator(generator, elem_sequence, 1, output_types=np.int64) + self._testFromGenerator(generator, elem_sequence, 5, output_types=np.int64) def testFromGeneratorUsingGeneratorExpression(self): # NOTE(mrry): Generator *expressions* are not repeatable (or in diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 8b3c2facbc..6a3f6bf40c 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -354,7 +354,7 @@ class Dataset(object): else: args = tuple(ops.convert_n_to_tensor(args, name="args")) - flattened_types = nest.flatten(output_types) + flattened_types = [dtypes.as_dtype(dt) for dt in nest.flatten(output_types)] flattened_shapes = nest.flatten(output_shapes) generator_state = Dataset._GeneratorState(generator) |