author     Jiri Simsa <jsimsa@google.com>                    2018-06-22 17:21:52 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-06-22 17:27:24 -0700
commit     b2b89083ae7f2da52ba1310f8224a46a9f64a437 (patch)
tree       063991ca94461ddbaa3a51bbb68e53452b536610
parent     c6c4116931a42dfafbcece3c4a4791c22120ed3b (diff)
[tf.data] Adding support for `DT_BOOL` for `tf.contrib.data.map_and_batch`.
PiperOrigin-RevId: 201765214
-rw-r--r--   tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py   26
-rw-r--r--   tensorflow/core/kernels/data/map_and_batch_dataset_op.cc                 8
-rw-r--r--   tensorflow/core/kernels/inplace_ops.cc                                    2
3 files changed, 30 insertions, 6 deletions
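
For context, a minimal sketch of what this change enables at the Python level, assuming the TF 1.x contrib API of the time (tf.contrib.data.map_and_batch with a one-shot iterator); it mirrors the new parameterized test below, and the constant False element plus the tf.logical_not map function are illustrative only:

import tensorflow as tf

# Scalar DT_BOOL elements, repeated and pushed through the fused map-and-batch.
dataset = tf.data.Dataset.from_tensors(False).repeat(100).apply(
    tf.contrib.data.map_and_batch(lambda x: tf.logical_not(x), batch_size=10))

get_next = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
  # Each run yields a length-10 bool vector of True values.
  print(sess.run(get_next))

Before this change, the kernels below dispatched only over TF_CALL_NUMBER_TYPES (plus string and variant), which does not include DT_BOOL.
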
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index d2e14f5fd7..af97fbf87a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -689,6 +689,32 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
+  @parameterized.parameters(
+      (False, dtypes.bool),
+      (-42, dtypes.int8),
+      (-42, dtypes.int16),
+      (-42, dtypes.int32),
+      (-42, dtypes.int64),
+      (42, dtypes.uint8),
+      (42, dtypes.uint16),
+      (42.0, dtypes.float16),
+      (42.0, dtypes.float32),
+      (42.0, dtypes.float64),
+      (b"hello", dtypes.string),
+  )
+  def testMapAndBatchTypes(self, element, dtype):
+    def gen():
+      yield element
+
+    dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply(
+        batching.map_and_batch(lambda x: x, batch_size=10))
+
+    get_next = dataset.make_one_shot_iterator().get_next()
+
+    with self.test_session() as sess:
+      for _ in range(10):
+        self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
+
 class RestructuredDatasetTest(test.TestCase):
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 002e0afcc2..004f153af6 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -370,7 +370,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
   Status CopyPartialBatch(Tensor* output, const Tensor& value,
                           int64 num_elements) {
     switch (value.dtype()) {
-#define CASE(type)                                           \
+#define HANDLE_TYPE(type)                                    \
   case DataTypeToEnum<type>::value: {                        \
     auto output_t = output->flat_outer_dims<type>();         \
     auto value_t = value.flat_outer_dims<type>();            \
@@ -379,10 +379,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
     }                                                        \
     return Status::OK();                                     \
   }
-      TF_CALL_NUMBER_TYPES(CASE);
-      TF_CALL_string(CASE);
-      TF_CALL_variant(CASE);
-#undef CASE
+      TF_CALL_DATASET_TYPES(HANDLE_TYPE);
+#undef HANDLE_TYPE
       default:
         return errors::InvalidArgument("Unsupported data type: ",
                                        value.dtype());
diff --git a/tensorflow/core/kernels/inplace_ops.cc b/tensorflow/core/kernels/inplace_ops.cc
index 8f51cc3819..8ddf3c38e8 100644
--- a/tensorflow/core/kernels/inplace_ops.cc
+++ b/tensorflow/core/kernels/inplace_ops.cc
@@ -50,7 +50,7 @@ Status DoParallelConcat(const CPUDevice& d, const Tensor& value, int32 loc,
 #define CASE(type)                                                  \
   case DataTypeToEnum<type>::value:                                 \
     return DoParallelConcatUpdate<CPUDevice, type>(d, value, loc, output);
-    TF_CALL_NUMBER_TYPES(CASE);
+    TF_CALL_POD_TYPES(CASE);
     TF_CALL_string(CASE);
     TF_CALL_variant(CASE);
 #undef CASE
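
Taken together, the two kernel changes let DT_BOOL reach both copy paths touched by this commit: TF_CALL_DATASET_TYPES in CopyPartialBatch covers the short final batch, and TF_CALL_POD_TYPES extends the in-place concat path in inplace_ops.cc. A hedged sketch of the partial-batch case, again assuming the TF 1.x contrib API; the element count and the x % 2 map function are illustrative only, chosen so the last batch is smaller than batch_size and exercises CopyPartialBatch:

import tensorflow as tf

# 25 bool elements batched by 10: two full batches of shape [10],
# then a partial final batch of shape [5].
dataset = tf.data.Dataset.range(25).apply(
    tf.contrib.data.map_and_batch(lambda x: tf.equal(x % 2, 0), batch_size=10))

get_next = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
  try:
    while True:
      print(sess.run(get_next))
  except tf.errors.OutOfRangeError:
    pass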