From b2b89083ae7f2da52ba1310f8224a46a9f64a437 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Fri, 22 Jun 2018 17:21:52 -0700 Subject: [tf.data] Adding support for `DT_BOOL` for `tf.contrib.data.map_and_batch`. PiperOrigin-RevId: 201765214 --- .../python/kernel_tests/batch_dataset_op_test.py | 26 ++++++++++++++++++++++ .../core/kernels/data/map_and_batch_dataset_op.cc | 8 +++---- tensorflow/core/kernels/inplace_ops.cc | 2 +- 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index d2e14f5fd7..af97fbf87a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -689,6 +689,32 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + @parameterized.parameters( + (False, dtypes.bool), + (-42, dtypes.int8), + (-42, dtypes.int16), + (-42, dtypes.int32), + (-42, dtypes.int64), + (42, dtypes.uint8), + (42, dtypes.uint16), + (42.0, dtypes.float16), + (42.0, dtypes.float32), + (42.0, dtypes.float64), + (b"hello", dtypes.string), + ) + def testMapAndBatchTypes(self, element, dtype): + def gen(): + yield element + + dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply( + batching.map_and_batch(lambda x: x, batch_size=10)) + + get_next = dataset.make_one_shot_iterator().get_next() + + with self.test_session() as sess: + for _ in range(10): + self.assertAllEqual([element for _ in range(10)], sess.run(get_next)) + class RestructuredDatasetTest(test.TestCase): diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index 002e0afcc2..004f153af6 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -370,7 +370,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status CopyPartialBatch(Tensor* output, const Tensor& value, int64 num_elements) { switch (value.dtype()) { -#define CASE(type) \ +#define HANDLE_TYPE(type) \ case DataTypeToEnum::value: { \ auto output_t = output->flat_outer_dims(); \ auto value_t = value.flat_outer_dims(); \ @@ -379,10 +379,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } \ return Status::OK(); \ } - TF_CALL_NUMBER_TYPES(CASE); - TF_CALL_string(CASE); - TF_CALL_variant(CASE); -#undef CASE + TF_CALL_DATASET_TYPES(HANDLE_TYPE); +#undef HANDLE_TYPE default: return errors::InvalidArgument("Unsupported data type: ", value.dtype()); diff --git a/tensorflow/core/kernels/inplace_ops.cc b/tensorflow/core/kernels/inplace_ops.cc index 8f51cc3819..8ddf3c38e8 100644 --- a/tensorflow/core/kernels/inplace_ops.cc +++ b/tensorflow/core/kernels/inplace_ops.cc @@ -50,7 +50,7 @@ Status DoParallelConcat(const CPUDevice& d, const Tensor& value, int32 loc, #define CASE(type) \ case DataTypeToEnum::value: \ return DoParallelConcatUpdate(d, value, loc, output); - TF_CALL_NUMBER_TYPES(CASE); + TF_CALL_POD_TYPES(CASE); TF_CALL_string(CASE); TF_CALL_variant(CASE); #undef CASE -- cgit v1.2.3