aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/kernels/cast_op.cc5
-rw-r--r--tensorflow/core/kernels/cast_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/cast_op_test.cc2
3 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc
index d88b948811..938b0f5ef5 100644
--- a/tensorflow/core/kernels/cast_op.cc
+++ b/tensorflow/core/kernels/cast_op.cc
@@ -60,6 +60,7 @@ struct CastFunctor<CPUDevice, O, I> {
FN(arg0, bool); \
FN(arg0, uint8); \
FN(arg0, int8); \
+ FN(arg0, uint16); \
FN(arg0, int16); \
FN(arg0, int32); \
FN(arg0, int64); \
@@ -70,6 +71,7 @@ struct CastFunctor<CPUDevice, O, I> {
FN(arg0, arg1, bool); \
FN(arg0, arg1, uint8); \
FN(arg0, arg1, int8); \
+ FN(arg0, arg1, uint16); \
FN(arg0, arg1, int16); \
FN(arg0, arg1, int32); \
FN(arg0, arg1, int64); \
@@ -134,6 +136,7 @@ class CpuCastOp : public CastOpBase {
CURRY_TYPES3(CAST_CASE, CPUDevice, bool);
CURRY_TYPES3(CAST_CASE, CPUDevice, uint8);
CURRY_TYPES3(CAST_CASE, CPUDevice, int8);
+ CURRY_TYPES3(CAST_CASE, CPUDevice, uint16);
CURRY_TYPES3(CAST_CASE, CPUDevice, int16);
CURRY_TYPES3(CAST_CASE, CPUDevice, int32);
CURRY_TYPES3(CAST_CASE, CPUDevice, int64);
@@ -197,6 +200,7 @@ class GpuCastOp : public CastOpBase {
CURRY_TYPES3(CAST_CASE, GPUDevice, bool);
CURRY_TYPES3(CAST_CASE, GPUDevice, uint8);
CURRY_TYPES3(CAST_CASE, GPUDevice, int8);
+ CURRY_TYPES3(CAST_CASE, GPUDevice, uint16);
CURRY_TYPES3(CAST_CASE, GPUDevice, int16);
CURRY_TYPES3(CAST_CASE, GPUDevice, int32);
CURRY_TYPES3(CAST_CASE, GPUDevice, int64);
@@ -223,6 +227,7 @@ REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp);
CURRY_TYPES2(REGISTER_CAST_GPU, bool);
CURRY_TYPES2(REGISTER_CAST_GPU, uint8);
CURRY_TYPES2(REGISTER_CAST_GPU, int8);
+CURRY_TYPES2(REGISTER_CAST_GPU, uint16);
CURRY_TYPES2(REGISTER_CAST_GPU, int16);
CURRY_TYPES2(REGISTER_CAST_GPU, int32);
CURRY_TYPES2(REGISTER_CAST_GPU, int64);
diff --git a/tensorflow/core/kernels/cast_op_gpu.cu.cc b/tensorflow/core/kernels/cast_op_gpu.cu.cc
index 03958d1e37..75d1785eda 100644
--- a/tensorflow/core/kernels/cast_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/cast_op_gpu.cu.cc
@@ -38,6 +38,7 @@ struct CastFunctor<GPUDevice, O, I> {
DEFINE(in_type, bool); \
DEFINE(in_type, uint8); \
DEFINE(in_type, int8); \
+ DEFINE(in_type, uint16); \
DEFINE(in_type, int16); \
DEFINE(in_type, int32); \
DEFINE(in_type, int64); \
@@ -47,6 +48,7 @@ struct CastFunctor<GPUDevice, O, I> {
DEFINE_ALL_FROM(bool);
DEFINE_ALL_FROM(uint8);
DEFINE_ALL_FROM(int8);
+DEFINE_ALL_FROM(uint16);
DEFINE_ALL_FROM(int16);
DEFINE_ALL_FROM(int32);
DEFINE_ALL_FROM(int64);
diff --git a/tensorflow/core/kernels/cast_op_test.cc b/tensorflow/core/kernels/cast_op_test.cc
index 50826a040d..23a3f77d00 100644
--- a/tensorflow/core/kernels/cast_op_test.cc
+++ b/tensorflow/core/kernels/cast_op_test.cc
@@ -66,6 +66,7 @@ class CastOpTest : public OpsTestBase {
#define TEST_ALL_CASTS_FROM(in) \
TEST_CAST(in, uint8); \
+ TEST_CAST(in, uint16); \
TEST_CAST(in, int16); \
TEST_CAST(in, int32); \
TEST_CAST(in, int64); \
@@ -73,6 +74,7 @@ class CastOpTest : public OpsTestBase {
TEST_CAST(in, double)
TEST_ALL_CASTS_FROM(uint8)
+TEST_ALL_CASTS_FROM(uint16)
TEST_ALL_CASTS_FROM(int16)
TEST_ALL_CASTS_FROM(int32)
TEST_ALL_CASTS_FROM(int64)