// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/cast_op.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

// Casts i into o, evaluating inline when the tensor is small enough that
// dispatching the expression to the Eigen device is not worthwhile.
template <typename Device, typename Tout, typename Tin>
void CastMaybeInline(const Device& d, typename TTypes<Tout>::Flat o,
                     typename TTypes<Tin>::ConstFlat i) {
  if (o.size() * (sizeof(Tin) + sizeof(Tout)) < 131072) {
    // Small cast on a CPU: do inline
    o = i.template cast<Tout>();
  } else {
    o.device(d) = i.template cast<Tout>();
  }
}

template <typename Tout, typename Tin>
struct CastFunctor<CPUDevice, Tout, Tin> {
  void operator()(const CPUDevice& d, typename TTypes<Tout>::Flat o,
                  typename TTypes<Tin>::ConstFlat i) {
    CastMaybeInline<CPUDevice, Tout, Tin>(d, o, i);
  }
};

}  // namespace functor

// If the requested (SrcT, DstT) pair matches (IN, OUT), installs a work_
// closure that runs CastFunctor<DEVICE, OUT, IN> on the flattened tensors and
// reports success.
#define CAST_CASE(DEVICE, IN, OUT)                                         \
  if (DataTypeToEnum<IN>::value == src_dtype_ &&                           \
      DataTypeToEnum<OUT>::value == dst_dtype_) {                          \
    work_ = [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) {     \
      functor::CastFunctor<DEVICE, OUT, IN> func;                          \
      func(ctx->eigen_device<DEVICE>(), out->flat<OUT>(), inp.flat<IN>()); \
    };                                                                     \
    return Status::OK();                                                   \
  }

class CastOpBase : public OpKernel {
 public:
  explicit CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &src_dtype_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    if (work_ == nullptr) {
      // Identity cast: forward the input tensor unchanged.
      ctx->set_output(0, inp);
    } else {
      Tensor* out = nullptr;
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
      work_(ctx, inp, out);
    }
  }

 protected:
  DataType src_dtype_;
  DataType dst_dtype_;
  std::function<void(OpKernelContext*, const Tensor&, Tensor*)> work_ =
      nullptr;

  virtual Status Prepare() = 0;
  Status Unimplemented() {
    return errors::Unimplemented("Cast ", DataTypeString(src_dtype_), " to ",
                                 DataTypeString(dst_dtype_),
                                 " is not supported");
  }

  TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase);
};

class CpuCastOp : public CastOpBase {
 public:
  explicit CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
    OP_REQUIRES_OK(ctx, Prepare());
  }

 protected:
  Status Prepare() override {
    if (src_dtype_ == dst_dtype_) {
      work_ = nullptr;  // Identity
      return Status::OK();
    }
    CAST_CASE(CPUDevice, bool, float);
    CAST_CASE(CPUDevice, bool, int32);
    CAST_CASE(CPUDevice, bool, double);
    CAST_CASE(CPUDevice, double, float);
    CAST_CASE(CPUDevice, double, int32);
    CAST_CASE(CPUDevice, double, int64);
    CAST_CASE(CPUDevice, float, double);
    CAST_CASE(CPUDevice, float, uint8);
    CAST_CASE(CPUDevice, float, int32);
    CAST_CASE(CPUDevice, float, int64);
    CAST_CASE(CPUDevice, int32, double);
    CAST_CASE(CPUDevice, int32, float);
    CAST_CASE(CPUDevice, int32, uint8);
    CAST_CASE(CPUDevice, int32, int64);
    CAST_CASE(CPUDevice, int64, double);
    CAST_CASE(CPUDevice, int64, float);
    CAST_CASE(CPUDevice, int64, int32);
    CAST_CASE(CPUDevice, uint8, float);
    CAST_CASE(CPUDevice, uint8, int32);
    CAST_CASE(CPUDevice, uint8, int64);
    CAST_CASE(CPUDevice, uint8, double);
    // bfloat16 <-> float casts use the dedicated converters, sharded across
    // the CPU worker threads for large tensors.
    if (src_dtype_ == DT_BFLOAT16 && dst_dtype_ == DT_FLOAT) {
      work_ = [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) {
        int64 N = out->NumElements();
        auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
        int num_threads = std::min<int>(
            std::min(4, worker_threads->num_threads), N / 4096);
        if (num_threads < 1) {
          BFloat16ToFloat(inp.flat<bfloat16>().data(),
                          out->flat<float>().data(), N);
        } else {
          auto work = [&inp, &out](int64 start, int64 end) {
            BFloat16ToFloat(inp.flat<bfloat16>().data() + start,
                            out->flat<float>().data() + start, end - start);
          };
          Shard(num_threads, worker_threads->workers, N, 100, work);
        }
      };
      return Status::OK();
    }
    if (src_dtype_ == DT_FLOAT && dst_dtype_ == DT_BFLOAT16) {
      work_ = [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) {
        int64 N = out->NumElements();
        auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
        int num_threads = std::min<int>(
            std::min(4, worker_threads->num_threads), N / 4096);
        if (num_threads < 1) {
          FloatToBFloat16(inp.flat<float>().data(),
                          out->flat<bfloat16>().data(), N);
        } else {
          auto work = [&inp, &out](int64 start, int64 end) {
            FloatToBFloat16(inp.flat<float>().data() + start,
                            out->flat<bfloat16>().data() + start, end - start);
          };
          Shard(num_threads, worker_threads->workers, N, 100, work);
        }
      };
      return Status::OK();
    }
    return Unimplemented();
  }
};

class GpuCastOp : public CastOpBase {
 public:
  explicit GpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
    OP_REQUIRES_OK(ctx, Prepare());
  }

 protected:
  Status Prepare() override {
    if (src_dtype_ == dst_dtype_) {
      work_ = nullptr;  // Identity
      return Status::OK();
    }
    CAST_CASE(GPUDevice, bfloat16, float);
    CAST_CASE(GPUDevice, bool, float);
    CAST_CASE(GPUDevice, double, float);
    CAST_CASE(GPUDevice, double, int64);
    CAST_CASE(GPUDevice, float, bfloat16);
    CAST_CASE(GPUDevice, float, double);
    CAST_CASE(GPUDevice, float, int64);
    CAST_CASE(GPUDevice, int64, double);
    CAST_CASE(GPUDevice, int64, float);
    CAST_CASE(GPUDevice, uint8, float);
    CAST_CASE(GPUDevice, float, uint8);
    CAST_CASE(GPUDevice, bool, int32);
    CAST_CASE(GPUDevice, double, int32);
    CAST_CASE(GPUDevice, float, int32);
    CAST_CASE(GPUDevice, int32, double);
    CAST_CASE(GPUDevice, int32, float);
    CAST_CASE(GPUDevice, int32, int64);
    CAST_CASE(GPUDevice, int64, int32);
    return Unimplemented();
  }
};

#undef CAST_CASE

REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp);

#if GOOGLE_CUDA
#define REGISTER_CAST_GPU(srctype, dsttype)                    \
  REGISTER_KERNEL_BUILDER(Name("Cast")                         \
                              .TypeConstraint<srctype>("SrcT") \
                              .TypeConstraint<dsttype>("DstT") \
                              .Device(DEVICE_GPU),             \
                          GpuCastOp);
REGISTER_CAST_GPU(bfloat16, float);
REGISTER_CAST_GPU(bool, float);
REGISTER_CAST_GPU(double, float);
REGISTER_CAST_GPU(double, int64);
REGISTER_CAST_GPU(float, bfloat16);
REGISTER_CAST_GPU(float, double);
REGISTER_CAST_GPU(float, int64);
REGISTER_CAST_GPU(int64, double);
REGISTER_CAST_GPU(int64, float);
REGISTER_CAST_GPU(uint8, float);
REGISTER_CAST_GPU(float, uint8);
REGISTER_CAST_GPU(bool, int32);
REGISTER_CAST_GPU(double, int32);
REGISTER_CAST_GPU(float, int32);
REGISTER_CAST_GPU(int32, double);
REGISTER_CAST_GPU(int32, float);
REGISTER_CAST_GPU(int32, int64);
REGISTER_CAST_GPU(int64, int32);
#undef REGISTER_CAST_GPU
#endif  // GOOGLE_CUDA

// HostCast differs from Cast in that its input and output are in host memory.
REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp);
REGISTER_KERNEL_BUILDER(
    Name("_HostCast").Device(DEVICE_GPU).HostMemory("x").HostMemory("y"),
    CpuCastOp);

}  // end namespace tensorflow
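// Extending the supported cast pairs (a sketch based on the structure above,
// not a complete recipe): the CPU "Cast" kernel is registered without type
// constraints, so a new CPU pair only needs another CAST_CASE(CPUDevice, IN,
// OUT) entry in CpuCastOp::Prepare(). A new GPU pair needs both a
// CAST_CASE(GPUDevice, IN, OUT) entry in GpuCastOp::Prepare() and a matching
// REGISTER_CAST_GPU(IN, OUT) line, plus a device-side CastFunctor
// specialization for that pair (assumed to live in a separate .cu.cc
// translation unit compiled under GOOGLE_CUDA).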