diff options
-rw-r--r-- | tensorflow/core/kernels/cast_op.cc | 5 | ||||
-rw-r--r-- | tensorflow/core/kernels/cast_op_gpu.cu.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/kernels/cast_op_test.cc | 2 |
3 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc index d88b948811..938b0f5ef5 100644 --- a/tensorflow/core/kernels/cast_op.cc +++ b/tensorflow/core/kernels/cast_op.cc @@ -60,6 +60,7 @@ struct CastFunctor<CPUDevice, O, I> { FN(arg0, bool); \ FN(arg0, uint8); \ FN(arg0, int8); \ + FN(arg0, uint16); \ FN(arg0, int16); \ FN(arg0, int32); \ FN(arg0, int64); \ @@ -70,6 +71,7 @@ struct CastFunctor<CPUDevice, O, I> { FN(arg0, arg1, bool); \ FN(arg0, arg1, uint8); \ FN(arg0, arg1, int8); \ + FN(arg0, arg1, uint16); \ FN(arg0, arg1, int16); \ FN(arg0, arg1, int32); \ FN(arg0, arg1, int64); \ @@ -134,6 +136,7 @@ class CpuCastOp : public CastOpBase { CURRY_TYPES3(CAST_CASE, CPUDevice, bool); CURRY_TYPES3(CAST_CASE, CPUDevice, uint8); CURRY_TYPES3(CAST_CASE, CPUDevice, int8); + CURRY_TYPES3(CAST_CASE, CPUDevice, uint16); CURRY_TYPES3(CAST_CASE, CPUDevice, int16); CURRY_TYPES3(CAST_CASE, CPUDevice, int32); CURRY_TYPES3(CAST_CASE, CPUDevice, int64); @@ -197,6 +200,7 @@ class GpuCastOp : public CastOpBase { CURRY_TYPES3(CAST_CASE, GPUDevice, bool); CURRY_TYPES3(CAST_CASE, GPUDevice, uint8); CURRY_TYPES3(CAST_CASE, GPUDevice, int8); + CURRY_TYPES3(CAST_CASE, GPUDevice, uint16); CURRY_TYPES3(CAST_CASE, GPUDevice, int16); CURRY_TYPES3(CAST_CASE, GPUDevice, int32); CURRY_TYPES3(CAST_CASE, GPUDevice, int64); @@ -223,6 +227,7 @@ REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp); CURRY_TYPES2(REGISTER_CAST_GPU, bool); CURRY_TYPES2(REGISTER_CAST_GPU, uint8); CURRY_TYPES2(REGISTER_CAST_GPU, int8); +CURRY_TYPES2(REGISTER_CAST_GPU, uint16); CURRY_TYPES2(REGISTER_CAST_GPU, int16); CURRY_TYPES2(REGISTER_CAST_GPU, int32); CURRY_TYPES2(REGISTER_CAST_GPU, int64); diff --git a/tensorflow/core/kernels/cast_op_gpu.cu.cc b/tensorflow/core/kernels/cast_op_gpu.cu.cc index 03958d1e37..75d1785eda 100644 --- a/tensorflow/core/kernels/cast_op_gpu.cu.cc +++ b/tensorflow/core/kernels/cast_op_gpu.cu.cc @@ -38,6 +38,7 @@ struct CastFunctor<GPUDevice, O, I> { DEFINE(in_type, bool); \ DEFINE(in_type, uint8); \ DEFINE(in_type, int8); \ + DEFINE(in_type, uint16); \ DEFINE(in_type, int16); \ DEFINE(in_type, int32); \ DEFINE(in_type, int64); \ @@ -47,6 +48,7 @@ struct CastFunctor<GPUDevice, O, I> { DEFINE_ALL_FROM(bool); DEFINE_ALL_FROM(uint8); DEFINE_ALL_FROM(int8); +DEFINE_ALL_FROM(uint16); DEFINE_ALL_FROM(int16); DEFINE_ALL_FROM(int32); DEFINE_ALL_FROM(int64); diff --git a/tensorflow/core/kernels/cast_op_test.cc b/tensorflow/core/kernels/cast_op_test.cc index 50826a040d..23a3f77d00 100644 --- a/tensorflow/core/kernels/cast_op_test.cc +++ b/tensorflow/core/kernels/cast_op_test.cc @@ -66,6 +66,7 @@ class CastOpTest : public OpsTestBase { #define TEST_ALL_CASTS_FROM(in) \ TEST_CAST(in, uint8); \ + TEST_CAST(in, uint16); \ TEST_CAST(in, int16); \ TEST_CAST(in, int32); \ TEST_CAST(in, int64); \ @@ -73,6 +74,7 @@ class CastOpTest : public OpsTestBase { TEST_CAST(in, double) TEST_ALL_CASTS_FROM(uint8) +TEST_ALL_CASTS_FROM(uint16) TEST_ALL_CASTS_FROM(int16) TEST_ALL_CASTS_FROM(int32) TEST_ALL_CASTS_FROM(int64) |