aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/batchtospace_op.cc
diff options
context:
space:
mode:
authorGravatar Yong Tang <yong.tang.github@outlook.com>2017-10-20 16:36:31 -0700
committerGravatar Vijay Vasudevan <vrv@google.com>2017-10-20 16:36:31 -0700
commitc77090a0ae61fc69fcdff7c58be9feb6121e3bd4 (patch)
treee0418715da60ca996cd15011d5a06c5cd1da8c50 /tensorflow/core/kernels/batchtospace_op.cc
parent9c825d32c9423980e1b263a50360e03e833b69a6 (diff)
Fix issues where int64 crops could not be passed to batch_to_space. (#13862)
* Fix issues where int64 crops could not be passed to batch_to_space. This fix tries to address the issue where int64 `crops` could not be passed to `batch_to_space` even though both int32 and int64 are specified as supported in the docs (tf.batch_to_space.__doc__) The reason is that BatchToSpace kernel puts a constraint of int32 to crops data types. This fix removed the constraint so that int64 `crops` could be supported. NOTE: Just removing the constraint should work and it is not necessary to add specification to the kernel class template, as `SubtleMustCopyFlat` called in the class already correctly handled both int32 and int64 cases. Besides, other data types (e.g., float or double) will not be passed to the kernel as they are guarded by the specification in `array_ops.cc`. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Also remove int64/int32 type constraints for SpaceToBatch kernels Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add test cases for int64 crops of batch_to_space and space_to_batch Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Fix test failures. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/core/kernels/batchtospace_op.cc')
-rw-r--r--tensorflow/core/kernels/batchtospace_op.cc50
1 files changed, 22 insertions, 28 deletions
diff --git a/tensorflow/core/kernels/batchtospace_op.cc b/tensorflow/core/kernels/batchtospace_op.cc
index 99b5d3daaa..c1c0d6d329 100644
--- a/tensorflow/core/kernels/batchtospace_op.cc
+++ b/tensorflow/core/kernels/batchtospace_op.cc
@@ -249,40 +249,34 @@ class BatchToSpaceOp : public OpKernel {
Tensor block_shape_;
};
-#define REGISTER(T) \
- REGISTER_KERNEL_BUILDER(Name("BatchToSpaceND") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .TypeConstraint<int32>("Tblock_shape") \
- .TypeConstraint<int32>("Tcrops") \
- .HostMemory("block_shape") \
- .HostMemory("crops"), \
- BatchToSpaceNDOp<CPUDevice, T>); \
- REGISTER_KERNEL_BUILDER(Name("BatchToSpace") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .TypeConstraint<int32>("Tidx") \
- .HostMemory("crops"), \
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER(Name("BatchToSpaceND") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("block_shape") \
+ .HostMemory("crops"), \
+ BatchToSpaceNDOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("BatchToSpace") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("crops"), \
BatchToSpaceOp<CPUDevice, T>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER);
#undef REGISTER
#if GOOGLE_CUDA
-#define REGISTER(T) \
- REGISTER_KERNEL_BUILDER(Name("BatchToSpaceND") \
- .Device(DEVICE_GPU) \
- .TypeConstraint<T>("T") \
- .TypeConstraint<int32>("Tblock_shape") \
- .TypeConstraint<int32>("Tcrops") \
- .HostMemory("block_shape") \
- .HostMemory("crops"), \
- BatchToSpaceNDOp<GPUDevice, T>); \
- REGISTER_KERNEL_BUILDER(Name("BatchToSpace") \
- .Device(DEVICE_GPU) \
- .TypeConstraint<T>("T") \
- .TypeConstraint<int32>("Tidx") \
- .HostMemory("crops"), \
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER(Name("BatchToSpaceND") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("block_shape") \
+ .HostMemory("crops"), \
+ BatchToSpaceNDOp<GPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("BatchToSpace") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("crops"), \
BatchToSpaceOp<GPUDevice, T>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER);