about summary refs log tree commit diff homepage
path: root/tensorflow/core/kernels/cast_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/cast_op.cc')
-rw-r--r--  tensorflow/core/kernels/cast_op.cc  80
1 file changed, 65 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc
index 626db9131a..0478c93280 100644
--- a/tensorflow/core/kernels/cast_op.cc
+++ b/tensorflow/core/kernels/cast_op.cc
@@ -41,8 +41,10 @@ typedef Eigen::SyclDevice SYCLDevice;
#define CURRY_TYPES2(FN, arg0) \
FN(arg0, bool); \
FN(arg0, uint8); \
- FN(arg0, int8); \
FN(arg0, uint16); \
+ FN(arg0, uint32); \
+ FN(arg0, uint64); \
+ FN(arg0, int8); \
FN(arg0, int16); \
FN(arg0, int32); \
FN(arg0, int64); \
@@ -53,8 +55,41 @@ typedef Eigen::SyclDevice SYCLDevice;
FN(arg0, std::complex<double>)
CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &src_dtype_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &external_src_dtype_));
+
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &external_dst_dtype_));
+
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Truncate", &use_truncation_));
+
+ // Quantized data types use the same underlying format as their non quantized
+ // version so we use the non quantized implementation for casting.
+ if (external_dst_dtype_ == DT_QUINT8) {
+ dst_dtype_ = DT_UINT8;
+ } else if (external_dst_dtype_ == DT_QINT8) {
+ dst_dtype_ = DT_INT8;
+ } else if (external_dst_dtype_ == DT_QINT32) {
+ dst_dtype_ = DT_INT32;
+ } else if (external_dst_dtype_ == DT_QINT16) {
+ dst_dtype_ = DT_INT16;
+ } else if (external_dst_dtype_ == DT_QUINT16) {
+ dst_dtype_ = DT_UINT16;
+ } else {
+ dst_dtype_ = external_dst_dtype_;
+ }
+
+ if (external_src_dtype_ == DT_QUINT8) {
+ src_dtype_ = DT_UINT8;
+ } else if (external_src_dtype_ == DT_QINT8) {
+ src_dtype_ = DT_INT8;
+ } else if (external_src_dtype_ == DT_QINT32) {
+ src_dtype_ = DT_INT32;
+ } else if (external_src_dtype_ == DT_QINT16) {
+ src_dtype_ = DT_INT16;
+ } else if (external_src_dtype_ == DT_QUINT16) {
+ src_dtype_ = DT_UINT16;
+ } else {
+ src_dtype_ = external_src_dtype_;
+ }
}
void CastOpBase::Compute(OpKernelContext* ctx) {
@@ -62,15 +97,20 @@ void CastOpBase::Compute(OpKernelContext* ctx) {
if (work_ == nullptr) {
ctx->set_output(0, inp);
} else {
+ Tensor in;
+ in.UnsafeCopyFromInternal(inp, src_dtype_, inp.shape());
Tensor* out = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
- work_(ctx, inp, out);
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
+ out->set_dtype(dst_dtype_);
+ work_(ctx, in, out, use_truncation_);
+ out->set_dtype(external_dst_dtype_);
}
}
Status CastOpBase::Unimplemented() {
- return errors::Unimplemented("Cast ", DataTypeString(src_dtype_), " to ",
- DataTypeString(dst_dtype_), " is not supported");
+ return errors::Unimplemented("Cast ", DataTypeString(external_src_dtype_),
+ " to ", DataTypeString(external_dst_dtype_),
+ " is not supported");
}
CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
@@ -78,7 +118,7 @@ CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
}
Status CpuCastOp::Prepare() {
- if (src_dtype_ == dst_dtype_) {
+ if (external_src_dtype_ == external_dst_dtype_) {
work_ = nullptr; // Identity
return Status::OK();
}
@@ -86,10 +126,14 @@ Status CpuCastOp::Prepare() {
work_ = GetCpuCastFromBool(dst_dtype_);
} else if (src_dtype_ == DT_UINT8) {
work_ = GetCpuCastFromUint8(dst_dtype_);
- } else if (src_dtype_ == DT_INT8) {
- work_ = GetCpuCastFromInt8(dst_dtype_);
} else if (src_dtype_ == DT_UINT16) {
work_ = GetCpuCastFromUint16(dst_dtype_);
+ } else if (src_dtype_ == DT_UINT32) {
+ work_ = GetCpuCastFromUint32(dst_dtype_);
+ } else if (src_dtype_ == DT_UINT64) {
+ work_ = GetCpuCastFromUint64(dst_dtype_);
+ } else if (src_dtype_ == DT_INT8) {
+ work_ = GetCpuCastFromInt8(dst_dtype_);
} else if (src_dtype_ == DT_INT16) {
work_ = GetCpuCastFromInt16(dst_dtype_);
} else if (src_dtype_ == DT_INT32) {
@@ -127,7 +171,7 @@ class GpuCastOp : public CastOpBase {
private:
Status Prepare() {
- if (src_dtype_ == dst_dtype_) {
+ if (external_src_dtype_ == external_dst_dtype_) {
work_ = nullptr; // Identity
return Status::OK();
}
@@ -135,10 +179,14 @@ class GpuCastOp : public CastOpBase {
work_ = GetGpuCastFromBool(dst_dtype_);
} else if (src_dtype_ == DT_UINT8) {
work_ = GetGpuCastFromUint8(dst_dtype_);
- } else if (src_dtype_ == DT_INT8) {
- work_ = GetGpuCastFromInt8(dst_dtype_);
} else if (src_dtype_ == DT_UINT16) {
work_ = GetGpuCastFromUint16(dst_dtype_);
+ } else if (src_dtype_ == DT_UINT32) {
+ work_ = GetGpuCastFromUint32(dst_dtype_);
+ } else if (src_dtype_ == DT_UINT64) {
+ work_ = GetGpuCastFromUint64(dst_dtype_);
+ } else if (src_dtype_ == DT_INT8) {
+ work_ = GetGpuCastFromInt8(dst_dtype_);
} else if (src_dtype_ == DT_INT16) {
work_ = GetGpuCastFromInt16(dst_dtype_);
} else if (src_dtype_ == DT_INT32) {
@@ -178,8 +226,10 @@ REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp);
CURRY_TYPES2(REGISTER_CAST_GPU, bool);
CURRY_TYPES2(REGISTER_CAST_GPU, uint8);
-CURRY_TYPES2(REGISTER_CAST_GPU, int8);
CURRY_TYPES2(REGISTER_CAST_GPU, uint16);
+CURRY_TYPES2(REGISTER_CAST_GPU, uint32);
+CURRY_TYPES2(REGISTER_CAST_GPU, uint64);
+CURRY_TYPES2(REGISTER_CAST_GPU, int8);
CURRY_TYPES2(REGISTER_CAST_GPU, int16);
CURRY_TYPES2(REGISTER_CAST_GPU, int32);
CURRY_TYPES2(REGISTER_CAST_GPU, int64);
@@ -203,7 +253,7 @@ class SyclCastOp : public CastOpBase {
private:
Status Prepare() {
- if (src_dtype_ == dst_dtype_) {
+ if (external_src_dtype_ == external_dst_dtype_) {
work_ = nullptr; // Identity
return Status::OK();
}