diff options
Diffstat (limited to 'tensorflow/core/kernels/cast_op.cc')
-rw-r--r-- | tensorflow/core/kernels/cast_op.cc | 80 |
1 files changed, 65 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc index 626db9131a..0478c93280 100644 --- a/tensorflow/core/kernels/cast_op.cc +++ b/tensorflow/core/kernels/cast_op.cc @@ -41,8 +41,10 @@ typedef Eigen::SyclDevice SYCLDevice; #define CURRY_TYPES2(FN, arg0) \ FN(arg0, bool); \ FN(arg0, uint8); \ - FN(arg0, int8); \ FN(arg0, uint16); \ + FN(arg0, uint32); \ + FN(arg0, uint64); \ + FN(arg0, int8); \ FN(arg0, int16); \ FN(arg0, int32); \ FN(arg0, int64); \ @@ -53,8 +55,41 @@ typedef Eigen::SyclDevice SYCLDevice; FN(arg0, std::complex<double>) CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &src_dtype_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &external_src_dtype_)); + + OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &external_dst_dtype_)); + + OP_REQUIRES_OK(ctx, ctx->GetAttr("Truncate", &use_truncation_)); + + // Quantized data types use the same underlying format as their non quantized + // version so we use the non quantized implementation for casting. + if (external_dst_dtype_ == DT_QUINT8) { + dst_dtype_ = DT_UINT8; + } else if (external_dst_dtype_ == DT_QINT8) { + dst_dtype_ = DT_INT8; + } else if (external_dst_dtype_ == DT_QINT32) { + dst_dtype_ = DT_INT32; + } else if (external_dst_dtype_ == DT_QINT16) { + dst_dtype_ = DT_INT16; + } else if (external_dst_dtype_ == DT_QUINT16) { + dst_dtype_ = DT_UINT16; + } else { + dst_dtype_ = external_dst_dtype_; + } + + if (external_src_dtype_ == DT_QUINT8) { + src_dtype_ = DT_UINT8; + } else if (external_src_dtype_ == DT_QINT8) { + src_dtype_ = DT_INT8; + } else if (external_src_dtype_ == DT_QINT32) { + src_dtype_ = DT_INT32; + } else if (external_src_dtype_ == DT_QINT16) { + src_dtype_ = DT_INT16; + } else if (external_src_dtype_ == DT_QUINT16) { + src_dtype_ = DT_UINT16; + } else { + src_dtype_ = external_src_dtype_; + } } void CastOpBase::Compute(OpKernelContext* ctx) { @@ -62,15 +97,20 @@ void CastOpBase::Compute(OpKernelContext* ctx) { if (work_ == nullptr) { ctx->set_output(0, inp); } else { + Tensor in; + in.UnsafeCopyFromInternal(inp, src_dtype_, inp.shape()); Tensor* out = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out)); - work_(ctx, inp, out); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out)); + out->set_dtype(dst_dtype_); + work_(ctx, in, out, use_truncation_); + out->set_dtype(external_dst_dtype_); } } Status CastOpBase::Unimplemented() { - return errors::Unimplemented("Cast ", DataTypeString(src_dtype_), " to ", - DataTypeString(dst_dtype_), " is not supported"); + return errors::Unimplemented("Cast ", DataTypeString(external_src_dtype_), + " to ", DataTypeString(external_dst_dtype_), + " is not supported"); } CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) { @@ -78,7 +118,7 @@ CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) { } Status CpuCastOp::Prepare() { - if (src_dtype_ == dst_dtype_) { + if (external_src_dtype_ == external_dst_dtype_) { work_ = nullptr; // Identity return Status::OK(); } @@ -86,10 +126,14 @@ Status CpuCastOp::Prepare() { work_ = GetCpuCastFromBool(dst_dtype_); } else if (src_dtype_ == DT_UINT8) { work_ = GetCpuCastFromUint8(dst_dtype_); - } else if (src_dtype_ == DT_INT8) { - work_ = GetCpuCastFromInt8(dst_dtype_); } else if (src_dtype_ == DT_UINT16) { work_ = GetCpuCastFromUint16(dst_dtype_); + } else if (src_dtype_ == DT_UINT32) { + work_ = GetCpuCastFromUint32(dst_dtype_); + } else if (src_dtype_ == DT_UINT64) { + work_ = GetCpuCastFromUint64(dst_dtype_); + } else if (src_dtype_ == DT_INT8) { + work_ = GetCpuCastFromInt8(dst_dtype_); } else if (src_dtype_ == DT_INT16) { work_ = GetCpuCastFromInt16(dst_dtype_); } else if (src_dtype_ == DT_INT32) { @@ -127,7 +171,7 @@ class GpuCastOp : public CastOpBase { private: Status Prepare() { - if (src_dtype_ == dst_dtype_) { + if (external_src_dtype_ == external_dst_dtype_) { work_ = nullptr; // Identity return Status::OK(); } @@ -135,10 +179,14 @@ class GpuCastOp : public CastOpBase { work_ = GetGpuCastFromBool(dst_dtype_); } else if (src_dtype_ == DT_UINT8) { work_ = GetGpuCastFromUint8(dst_dtype_); - } else if (src_dtype_ == DT_INT8) { - work_ = GetGpuCastFromInt8(dst_dtype_); } else if (src_dtype_ == DT_UINT16) { work_ = GetGpuCastFromUint16(dst_dtype_); + } else if (src_dtype_ == DT_UINT32) { + work_ = GetGpuCastFromUint32(dst_dtype_); + } else if (src_dtype_ == DT_UINT64) { + work_ = GetGpuCastFromUint64(dst_dtype_); + } else if (src_dtype_ == DT_INT8) { + work_ = GetGpuCastFromInt8(dst_dtype_); } else if (src_dtype_ == DT_INT16) { work_ = GetGpuCastFromInt16(dst_dtype_); } else if (src_dtype_ == DT_INT32) { @@ -178,8 +226,10 @@ REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp); CURRY_TYPES2(REGISTER_CAST_GPU, bool); CURRY_TYPES2(REGISTER_CAST_GPU, uint8); -CURRY_TYPES2(REGISTER_CAST_GPU, int8); CURRY_TYPES2(REGISTER_CAST_GPU, uint16); +CURRY_TYPES2(REGISTER_CAST_GPU, uint32); +CURRY_TYPES2(REGISTER_CAST_GPU, uint64); +CURRY_TYPES2(REGISTER_CAST_GPU, int8); CURRY_TYPES2(REGISTER_CAST_GPU, int16); CURRY_TYPES2(REGISTER_CAST_GPU, int32); CURRY_TYPES2(REGISTER_CAST_GPU, int64); @@ -203,7 +253,7 @@ class SyclCastOp : public CastOpBase { private: Status Prepare() { - if (src_dtype_ == dst_dtype_) { + if (external_src_dtype_ == external_dst_dtype_) { work_ = nullptr; // Identity return Status::OK(); } |