Diffstat (limited to 'tensorflow')
25 files changed, 342 insertions, 216 deletions
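The patch below threads a new boolean "Truncate" attribute through the Cast op and its CPU/GPU/SYCL kernels so that float-to-bfloat16 casts can keep the existing truncation semantics once the bfloat16 constructor switches to round-to-nearest-even by default. For orientation, the two conversions differ only in how the low 16 bits of the float are handled; the sketch below is not part of the patch and its helper names are illustrative, but it shows truncation next to the round-to-nearest-even scheme used by round_to_bfloat16.

// Standalone sketch of float -> bfloat16 conversion (illustrative only, not
// code from this patch). Truncation drops the low 16 bits; rounding adds a
// bias derived from the surviving LSB before dropping them.
#include <cstdint>
#include <cstdio>
#include <cstring>

static uint16_t TruncateToBfloat16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));  // NaN handling omitted for brevity
  return static_cast<uint16_t>(bits >> 16);
}

static uint16_t RoundToBfloat16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  uint32_t lsb = (bits >> 16) & 1u;        // LSB that survives the narrowing
  uint32_t rounding_bias = 0x7FFFu + lsb;  // ties round to even
  return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}

int main() {
  float v = 1.005859375f;  // 1 + 3 * 2^-9: the dropped bits trigger rounding
  unsigned t = TruncateToBfloat16(v);
  unsigned r = RoundToBfloat16(v);
  std::printf("truncate: 0x%04X  round: 0x%04X\n", t, r);  // 0x3F80 vs 0x3F81
}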
diff --git a/tensorflow/core/framework/bfloat16_test.cc b/tensorflow/core/framework/bfloat16_test.cc index 206396a25a..0a1b5e1975 100644 --- a/tensorflow/core/framework/bfloat16_test.cc +++ b/tensorflow/core/framework/bfloat16_test.cc @@ -45,7 +45,8 @@ class Bfloat16Test : public ::testing::Test, public ::testing::WithParamInterface<Bfloat16TestParam> {}; TEST_P(Bfloat16Test, TruncateTest) { - bfloat16 truncated(GetParam().input); + bfloat16 truncated = bfloat16::truncate_to_bfloat16((GetParam().input)); + if (std::isnan(GetParam().input)) { EXPECT_TRUE(std::isnan(float(truncated)) || std::isinf(float(truncated))); return; diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index 1b1941f9c1..ea0a814ab8 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -214,6 +214,14 @@ NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info, cast_builder.Attr("_start_time", start_time); } cast_builder.Attr("DstT", cast_dtype); + + if (cast_dtype == DT_BFLOAT16) { + // the below attribute specifies that the cast to bfloat16 should use + // truncation. This is needed to retain legacy behavior when we change + // the default bfloat16 casts to use rounding instead of truncation + cast_builder.Attr("Truncate", true); + } + NodeDef* cast = gdef->add_node(); *status = cast_builder.Finalize(cast); if (!status->ok()) return nullptr; diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc index b4c97df38b..0478c93280 100644 --- a/tensorflow/core/kernels/cast_op.cc +++ b/tensorflow/core/kernels/cast_op.cc @@ -59,6 +59,8 @@ CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &external_dst_dtype_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Truncate", &use_truncation_)); + // Quantized data types use the same underlying format as their non quantized // version so we use the non quantized implementation for casting. if (external_dst_dtype_ == DT_QUINT8) { @@ -100,7 +102,7 @@ void CastOpBase::Compute(OpKernelContext* ctx) { Tensor* out = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out)); out->set_dtype(dst_dtype_); - work_(ctx, in, out); + work_(ctx, in, out, use_truncation_); out->set_dtype(external_dst_dtype_); } } diff --git a/tensorflow/core/kernels/cast_op.h b/tensorflow/core/kernels/cast_op.h index aae1e7ff19..527ab528c9 100644 --- a/tensorflow/core/kernels/cast_op.h +++ b/tensorflow/core/kernels/cast_op.h @@ -24,8 +24,71 @@ limitations under the License. #include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/types.h" +// Note that the GPU cast functor templates need to be instantiated unlike the +// CPU ones, and hence their specializations are different than that for CPUs. 
+#ifdef SPECIALIZE_FOR_GPUS +#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT) \ + template <typename Device> \ + struct CastFunctor<Device, OUT_TYPE, IN_OUT> { \ + void operator()(const Device& d, \ + typename TTypes<OUT_TYPE>::Flat out_tensor, \ + typename TTypes<IN_OUT>::ConstFlat in_tensor, \ + bool truncate = false) { \ + if (truncate) { \ + out_tensor.device(d) = \ + in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>()) \ + .template cast<OUT_TYPE>(); \ + } else { \ + out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \ + } \ + } \ + }; \ + template struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT>; +#else +#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT) \ + template <> \ + struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT> { \ + void operator()(const DEVICE& d, \ + typename TTypes<OUT_TYPE>::Flat out_tensor, \ + typename TTypes<IN_OUT>::ConstFlat in_tensor, \ + bool truncate = false) { \ + if (truncate) { \ + out_tensor.device(d) = \ + in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>()) \ + .template cast<OUT_TYPE>(); \ + } else { \ + out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \ + } \ + } \ + }; +#endif + +#define CAST_FUNCTORS(devname) \ + SPECIALIZE_CAST(devname, float, double) \ + SPECIALIZE_CAST(devname, float, std::complex<double>) \ + SPECIALIZE_CAST(devname, std::complex<float>, std::complex<double>) \ + SPECIALIZE_CAST(devname, std::complex<float>, double) \ + SPECIALIZE_CAST(devname, Eigen::half, double) \ + SPECIALIZE_CAST(devname, Eigen::half, float) \ + SPECIALIZE_CAST(devname, Eigen::half, std::complex<double>) \ + SPECIALIZE_CAST(devname, Eigen::half, std::complex<float>) \ + SPECIALIZE_CAST(devname, bfloat16, float) \ + template <typename OUT_TYPE, typename IN_OUT> \ + struct CastFunctor<devname, OUT_TYPE, IN_OUT> { \ + void operator()(const devname& d, \ + typename TTypes<OUT_TYPE>::Flat out_tensor, \ + typename TTypes<IN_OUT>::ConstFlat in_tensor, \ + bool truncate = false) { \ + out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \ + } \ + }; + namespace tensorflow { +typedef std::function<void(OpKernelContext*, const Tensor&, Tensor*, + bool trunc)> + CastFunctorType; + // Common base class of Cast kernels class CastOpBase : public OpKernel { public: @@ -38,8 +101,8 @@ class CastOpBase : public OpKernel { DataType dst_dtype_; DataType external_src_dtype_; DataType external_dst_dtype_; - std::function<void(OpKernelContext*, const Tensor&, Tensor*)> work_ = nullptr; - + bool use_truncation_; + CastFunctorType work_ = nullptr; Status Unimplemented(); TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase); @@ -56,6 +119,23 @@ class CpuCastOp : public CastOpBase { namespace functor { +template <typename I> +constexpr int MantissaWidth() { + return std::numeric_limits<I>::digits; +} + +template <> +constexpr int MantissaWidth<Eigen::half>() { + // Remember, there's 1 hidden bit + return 10 + 1; +} + +template <> +constexpr int MantissaWidth<bfloat16>() { + // Remember, there's 1 hidden bit + return 7 + 1; +} + template <typename Device, typename Tout, typename Tin> void Cast(const Device& d, typename TTypes<Tout>::Flat o, typename TTypes<Tin>::ConstFlat i) { @@ -65,7 +145,85 @@ void Cast(const Device& d, typename TTypes<Tout>::Flat o, template <typename Device, typename Tout, typename Tin> struct CastFunctor { void operator()(const Device& d, typename TTypes<Tout>::Flat o, - typename TTypes<Tin>::ConstFlat i); + typename TTypes<Tin>::ConstFlat i, bool truncate = false); +}; + +// Only enable LSBZeroSetterHelper for 64 and 32 bit input data types. 
+// Specialize for others if needed in future. +template <typename I> +typename std::enable_if<sizeof(I) == 8, void>::type EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) { + // Only zero the bits for non-NaNs. + // For NaNs, let the non-truncation version handle it. + if (!std::isnan(t)) { + uint64_t* p = reinterpret_cast<uint64_t*>(&t); + *p &= (0xFFFFFFFFFFFFFFFF << n); + } +} + +template <typename I> +typename std::enable_if<sizeof(I) == 4, void>::type EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) { + // Only zero the bits for non-NaNs. + // For NaNs, let the non-truncation version handle it. + if (!std::isnan(t)) { + uint32_t* p = reinterpret_cast<uint32_t*>(&t); + *p &= (0xFFFFFFFF << n); + } +} + +// Set n least significant bits to 0 +template <typename I, typename O> +struct LSBZeroSetter { + EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter) + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const I operator()(const I& a) const { + constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>(); + static_assert( + bits > 0, + "The output type must have fewer mantissa bits than the input type\n"); + I t = a; + LSBZeroSetterHelper(t, bits); + return t; + } +}; + +template <typename I, typename O> +struct LSBZeroSetter<std::complex<I>, std::complex<O>> { + EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter) + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()( + const std::complex<I>& a) const { + constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>(); + static_assert( + bits > 0, + "The output type must have fewer mantissa bits than the input type\n"); + I re = std::real(a); + I img = std::imag(a); + LSBZeroSetterHelper(re, bits); + LSBZeroSetterHelper(img, bits); + std::complex<I> toReturn(re, img); + return toReturn; + } +}; + +template <typename I, typename O> +struct LSBZeroSetter<std::complex<I>, O> { + EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter) + // Sets the 16 LSBits of the float to 0 + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()( + const std::complex<I>& a) const { + constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>(); + static_assert( + bits > 0, + "The output type must have fewer mantissa bits than the input type\n"); + I re = std::real(a); + I img = std::imag(a); + LSBZeroSetterHelper(re, bits); + LSBZeroSetterHelper(img, bits); + std::complex<I> toReturn(re, img); + return toReturn; + } }; } // end namespace functor diff --git a/tensorflow/core/kernels/cast_op_gpu.cu.cc b/tensorflow/core/kernels/cast_op_gpu.cu.cc index 607e7f5efd..036996fca2 100644 --- a/tensorflow/core/kernels/cast_op_gpu.cu.cc +++ b/tensorflow/core/kernels/cast_op_gpu.cu.cc @@ -18,22 +18,19 @@ limitations under the License. 
#define EIGEN_USE_GPU #include "tensorflow/core/framework/bfloat16.h" +#define SPECIALIZE_FOR_GPUS #include "tensorflow/core/kernels/cast_op.h" +#undef SPECIALIZE_FOR_GPUS namespace tensorflow { namespace functor { typedef Eigen::GpuDevice GPUDevice; -template <typename O, typename I> -struct CastFunctor<GPUDevice, O, I> { - void operator()(const GPUDevice& d, typename TTypes<O>::Flat o, - typename TTypes<I>::ConstFlat i) { - Cast<GPUDevice, O, I>(d, o, i); - } -}; +CAST_FUNCTORS(GPUDevice); #define DEFINE(O, I) template struct CastFunctor<GPUDevice, O, I> + #define DEFINE_ALL_FROM(in_type) \ DEFINE(in_type, bool); \ DEFINE(in_type, uint8); \ @@ -59,14 +56,43 @@ DEFINE_ALL_FROM(int8); DEFINE_ALL_FROM(int16); DEFINE_ALL_FROM(int32); DEFINE_ALL_FROM(int64); -DEFINE_ALL_FROM(Eigen::half); -DEFINE_ALL_FROM(float); DEFINE_ALL_FROM(double); -DEFINE_ALL_FROM(std::complex<float>); DEFINE_ALL_FROM(std::complex<double>); -DEFINE(bfloat16, float); DEFINE(float, bfloat16); +#define DEFINE_ALL_TO_FLOAT(out_type) \ + DEFINE(out_type, bool); \ + DEFINE(out_type, uint8); \ + DEFINE(out_type, uint16); \ + DEFINE(out_type, uint32); \ + DEFINE(out_type, uint64); \ + DEFINE(out_type, int8); \ + DEFINE(out_type, int16); \ + DEFINE(out_type, int32); \ + DEFINE(out_type, int64); \ + DEFINE(out_type, Eigen::half); \ + DEFINE(out_type, float); \ + DEFINE(out_type, std::complex<float>) + +#define DEFINE_ALL_TO_HALF(out_type) \ + DEFINE(out_type, bool); \ + DEFINE(out_type, uint8); \ + DEFINE(out_type, uint16); \ + DEFINE(out_type, uint32); \ + DEFINE(out_type, uint64); \ + DEFINE(out_type, int8); \ + DEFINE(out_type, int16); \ + DEFINE(out_type, int32); \ + DEFINE(out_type, int64); \ + DEFINE(out_type, Eigen::half) + +DEFINE_ALL_TO_HALF(Eigen::half); +DEFINE_ALL_TO_HALF(bfloat16); +DEFINE_ALL_TO_FLOAT(float); +DEFINE_ALL_TO_FLOAT(std::complex<float>); + +#undef DEFINE_ALL_TO_FLOAT +#undef DEFINE_ALL_TO_HALF #undef DEFINE_ALL_FROM #undef DEFINE diff --git a/tensorflow/core/kernels/cast_op_impl.h b/tensorflow/core/kernels/cast_op_impl.h index fe821b25df..b899bac681 100644 --- a/tensorflow/core/kernels/cast_op_impl.h +++ b/tensorflow/core/kernels/cast_op_impl.h @@ -25,22 +25,10 @@ namespace tensorflow { namespace functor { -template <typename O, typename I> -struct CastFunctor<Eigen::ThreadPoolDevice, O, I> { - void operator()(const Eigen::ThreadPoolDevice& d, typename TTypes<O>::Flat o, - typename TTypes<I>::ConstFlat i) { - o.device(d) = i.template cast<O>(); - } -}; +CAST_FUNCTORS(Eigen::ThreadPoolDevice); #ifdef TENSORFLOW_USE_SYCL -template <typename O, typename I> -struct CastFunctor<Eigen::SyclDevice, O, I> { - void operator()(const Eigen::SyclDevice& d, typename TTypes<O>::Flat o, - typename TTypes<I>::ConstFlat i) { - o.device(d) = i.template cast<O>(); - } -}; +CAST_FUNCTORS(Eigen::SyclDevice); #endif // TENSORFLOW_USE_SYCL } // namespace functor @@ -68,139 +56,103 @@ struct CastFunctor<Eigen::SyclDevice, O, I> { CURRY_TYPES3_NO_BF16(FN, arg0, arg1) \ FN(arg0, arg1, bfloat16); -#define CAST_CASE(DEVICE, IN, OUT) \ - if (DataTypeToEnum<OUT>::value == dst_dtype) { \ - return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) { \ - functor::CastFunctor<DEVICE, OUT, IN> func; \ - func(ctx->eigen_device<DEVICE>(), out->flat<OUT>(), inp.flat<IN>()); \ - }; \ +#define CAST_CASE(DEVICE, IN, OUT) \ + if (DataTypeToEnum<OUT>::value == dst_dtype) { \ + return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out, \ + bool truncate) { \ + functor::CastFunctor<DEVICE, OUT, IN> func; \ + 
func(ctx->eigen_device<DEVICE>(), out->flat<OUT>(), inp.flat<IN>(), \ + truncate); \ + }; \ } // The functions below are implemented in the cast_op_impl_*.cc files. -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromBool(DataType dst_dtype); +CastFunctorType GetCpuCastFromBool(DataType dst_dtype); + +CastFunctorType GetCpuCastFromUint8(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromUint8(DataType dst_dtype); +CastFunctorType GetCpuCastFromUint16(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromUint16(DataType dst_dtype); +CastFunctorType GetCpuCastFromInt8(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromUint32(DataType dst_dtype); +CastFunctorType GetCpuCastFromUint32(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromUint64(DataType dst_dtype); +CastFunctorType GetCpuCastFromUint64(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromInt8(DataType dst_dtype); +CastFunctorType GetCpuCastFromInt8(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromInt16(DataType dst_dtype); +CastFunctorType GetCpuCastFromInt16(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromInt32(DataType dst_dtype); +CastFunctorType GetCpuCastFromInt32(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromInt64(DataType dst_dtype); +CastFunctorType GetCpuCastFromInt64(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromHalf(DataType dst_dtype); +CastFunctorType GetCpuCastFromHalf(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromFloat(DataType dst_dtype); +CastFunctorType GetCpuCastFromFloat(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromDouble(DataType dst_dtype); +CastFunctorType GetCpuCastFromDouble(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromComplex64(DataType dst_dtype); +CastFunctorType GetCpuCastFromComplex64(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromComplex128(DataType dst_dtype); +CastFunctorType GetCpuCastFromComplex128(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromBfloat(DataType dst_dtype); +CastFunctorType GetCpuCastFromBfloat(DataType dst_dtype); #if GOOGLE_CUDA // Same, for GPU. 
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromBool(DataType dst_dtype); +CastFunctorType GetGpuCastFromBool(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromUint8(DataType dst_dtype); +CastFunctorType GetGpuCastFromUint8(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromUint16(DataType dst_dtype); +CastFunctorType GetGpuCastFromUint16(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromUint32(DataType dst_dtype); +CastFunctorType GetGpuCastFromInt8(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromUint64(DataType dst_dtype); +CastFunctorType GetGpuCastFromUint32(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromInt8(DataType dst_dtype); +CastFunctorType GetGpuCastFromUint64(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromInt16(DataType dst_dtype); +CastFunctorType GetGpuCastFromInt16(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromInt32(DataType dst_dtype); +CastFunctorType GetGpuCastFromInt32(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromInt64(DataType dst_dtype); +CastFunctorType GetGpuCastFromInt64(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromHalf(DataType dst_dtype); +CastFunctorType GetGpuCastFromHalf(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromFloat(DataType dst_dtype); +CastFunctorType GetGpuCastFromFloat(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromDouble(DataType dst_dtype); +CastFunctorType GetGpuCastFromDouble(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromComplex64(DataType dst_dtype); +CastFunctorType GetGpuCastFromComplex64(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromComplex128(DataType dst_dtype); +CastFunctorType GetGpuCastFromComplex128(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromBfloat(DataType dst_dtype); +CastFunctorType GetGpuCastFromBfloat(DataType dst_dtype); #endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromBool(DataType dst_dtype); +CastFunctorType GetSyclCastFromBool(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromUint8(DataType dst_dtype); +CastFunctorType GetSyclCastFromUint8(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromUint16(DataType dst_dtype); +CastFunctorType GetSyclCastFromUint16(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromUint32(DataType dst_dtype); +CastFunctorType GetSyclCastFromUint32(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromUint64(DataType dst_dtype); +CastFunctorType GetSyclCastFromUint64(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromInt16(DataType dst_dtype); +CastFunctorType GetSyclCastFromInt16(DataType dst_dtype); -std::function<void(OpKernelContext*, const 
Tensor&, Tensor*)> -GetSyclCastFromInt32(DataType dst_dtype); +CastFunctorType GetSyclCastFromInt32(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromInt64(DataType dst_dtype); +CastFunctorType GetSyclCastFromInt64(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromFloat(DataType dst_dtype); +CastFunctorType GetSyclCastFromFloat(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromDouble(DataType dst_dtype); +CastFunctorType GetSyclCastFromDouble(DataType dst_dtype); #endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cast_op_impl_bfloat.cc b/tensorflow/core/kernels/cast_op_impl_bfloat.cc index bfa7ba0d47..96aae15608 100644 --- a/tensorflow/core/kernels/cast_op_impl_bfloat.cc +++ b/tensorflow/core/kernels/cast_op_impl_bfloat.cc @@ -22,20 +22,19 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromBfloat(DataType dst_dtype) { +CastFunctorType GetCpuCastFromBfloat(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, bfloat16); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromBfloat(DataType dst_dtype) { +CastFunctorType GetGpuCastFromBfloat(DataType dst_dtype) { if (dst_dtype == DT_FLOAT) { - return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) { + return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out, + bool truncate) { functor::CastFunctor<GPUDevice, float, bfloat16> func; func(ctx->eigen_device<GPUDevice>(), out->flat<float>(), - inp.flat<bfloat16>()); + inp.flat<bfloat16>(), truncate); }; } return nullptr; diff --git a/tensorflow/core/kernels/cast_op_impl_bool.cc b/tensorflow/core/kernels/cast_op_impl_bool.cc index c5c7394b43..792d4781f2 100644 --- a/tensorflow/core/kernels/cast_op_impl_bool.cc +++ b/tensorflow/core/kernels/cast_op_impl_bool.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromBool(DataType dst_dtype) { +CastFunctorType GetCpuCastFromBool(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, bool); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromBool(DataType dst_dtype) { +CastFunctorType GetGpuCastFromBool(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, bool); return nullptr; } @@ -36,8 +34,7 @@ GetGpuCastFromBool(DataType dst_dtype) { #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromBool(DataType dst_dtype) { +CastFunctorType GetSyclCastFromBool(DataType dst_dtype) { CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, bool); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_complex128.cc b/tensorflow/core/kernels/cast_op_impl_complex128.cc index 52899d58cd..9a184e5954 100644 --- a/tensorflow/core/kernels/cast_op_impl_complex128.cc +++ b/tensorflow/core/kernels/cast_op_impl_complex128.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromComplex128(DataType dst_dtype) { +CastFunctorType 
GetCpuCastFromComplex128(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, std::complex<double>); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromComplex128(DataType dst_dtype) { +CastFunctorType GetGpuCastFromComplex128(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, std::complex<double>); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_complex64.cc b/tensorflow/core/kernels/cast_op_impl_complex64.cc index 617bda53d5..77bc620b46 100644 --- a/tensorflow/core/kernels/cast_op_impl_complex64.cc +++ b/tensorflow/core/kernels/cast_op_impl_complex64.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromComplex64(DataType dst_dtype) { +CastFunctorType GetCpuCastFromComplex64(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, std::complex<float>); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromComplex64(DataType dst_dtype) { +CastFunctorType GetGpuCastFromComplex64(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, std::complex<float>); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_double.cc b/tensorflow/core/kernels/cast_op_impl_double.cc index 7dc485ddad..ff9056897f 100644 --- a/tensorflow/core/kernels/cast_op_impl_double.cc +++ b/tensorflow/core/kernels/cast_op_impl_double.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromDouble(DataType dst_dtype) { +CastFunctorType GetCpuCastFromDouble(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, double); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromDouble(DataType dst_dtype) { +CastFunctorType GetGpuCastFromDouble(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, double); return nullptr; } @@ -36,8 +34,7 @@ GetGpuCastFromDouble(DataType dst_dtype) { #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromDouble(DataType dst_dtype) { +CastFunctorType GetSyclCastFromDouble(DataType dst_dtype) { CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, double); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_float.cc b/tensorflow/core/kernels/cast_op_impl_float.cc index 1c933914fd..f1e8f0e37b 100644 --- a/tensorflow/core/kernels/cast_op_impl_float.cc +++ b/tensorflow/core/kernels/cast_op_impl_float.cc @@ -22,15 +22,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromFloat(DataType dst_dtype) { +CastFunctorType GetCpuCastFromFloat(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, float); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromFloat(DataType dst_dtype) { +CastFunctorType GetGpuCastFromFloat(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, GPUDevice, float); return nullptr; } @@ -38,8 +36,7 @@ GetGpuCastFromFloat(DataType dst_dtype) { #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> 
-GetSyclCastFromFloat(DataType dst_dtype) { +CastFunctorType GetSyclCastFromFloat(DataType dst_dtype) { CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, float); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_half.cc b/tensorflow/core/kernels/cast_op_impl_half.cc index ef4b94e326..5da3a01352 100644 --- a/tensorflow/core/kernels/cast_op_impl_half.cc +++ b/tensorflow/core/kernels/cast_op_impl_half.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromHalf(DataType dst_dtype) { +CastFunctorType GetCpuCastFromHalf(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, Eigen::half); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromHalf(DataType dst_dtype) { +CastFunctorType GetGpuCastFromHalf(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, Eigen::half); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_int16.cc b/tensorflow/core/kernels/cast_op_impl_int16.cc index 59360f7445..440ee88fb5 100644 --- a/tensorflow/core/kernels/cast_op_impl_int16.cc +++ b/tensorflow/core/kernels/cast_op_impl_int16.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromInt16(DataType dst_dtype) { +CastFunctorType GetCpuCastFromInt16(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, int16); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromInt16(DataType dst_dtype) { +CastFunctorType GetGpuCastFromInt16(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, int16); return nullptr; } @@ -36,8 +34,7 @@ GetGpuCastFromInt16(DataType dst_dtype) { #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromInt16(DataType dst_dtype) { +CastFunctorType GetSyclCastFromInt16(DataType dst_dtype) { CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int16); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_int32.cc b/tensorflow/core/kernels/cast_op_impl_int32.cc index a867392fde..4b3e7efddc 100644 --- a/tensorflow/core/kernels/cast_op_impl_int32.cc +++ b/tensorflow/core/kernels/cast_op_impl_int32.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromInt32(DataType dst_dtype) { +CastFunctorType GetCpuCastFromInt32(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, int32); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromInt32(DataType dst_dtype) { +CastFunctorType GetGpuCastFromInt32(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, int32); return nullptr; } @@ -36,8 +34,7 @@ GetGpuCastFromInt32(DataType dst_dtype) { #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromInt32(DataType dst_dtype) { +CastFunctorType GetSyclCastFromInt32(DataType dst_dtype) { CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int32); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_int64.cc b/tensorflow/core/kernels/cast_op_impl_int64.cc index 
467a8f6c89..0f711aa560 100644 --- a/tensorflow/core/kernels/cast_op_impl_int64.cc +++ b/tensorflow/core/kernels/cast_op_impl_int64.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromInt64(DataType dst_dtype) { +CastFunctorType GetCpuCastFromInt64(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, int64); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromInt64(DataType dst_dtype) { +CastFunctorType GetGpuCastFromInt64(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, int64); return nullptr; } @@ -36,8 +34,7 @@ GetGpuCastFromInt64(DataType dst_dtype) { #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromInt64(DataType dst_dtype) { +CastFunctorType GetSyclCastFromInt64(DataType dst_dtype) { CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int64); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_int8.cc b/tensorflow/core/kernels/cast_op_impl_int8.cc index 21002a4321..eac185d5a0 100644 --- a/tensorflow/core/kernels/cast_op_impl_int8.cc +++ b/tensorflow/core/kernels/cast_op_impl_int8.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromInt8(DataType dst_dtype) { +CastFunctorType GetCpuCastFromInt8(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, int8); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromInt8(DataType dst_dtype) { +CastFunctorType GetGpuCastFromInt8(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, int8); return nullptr; } @@ -36,8 +34,7 @@ GetGpuCastFromInt8(DataType dst_dtype) { #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromInt8(DataType dst_dtype) { +CastFunctorType GetSyclCastFromInt8(DataType dst_dtype) { CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int8); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_uint16.cc b/tensorflow/core/kernels/cast_op_impl_uint16.cc index cd829bae2a..3aebbdc1f3 100644 --- a/tensorflow/core/kernels/cast_op_impl_uint16.cc +++ b/tensorflow/core/kernels/cast_op_impl_uint16.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromUint16(DataType dst_dtype) { +CastFunctorType GetCpuCastFromUint16(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, uint16); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromUint16(DataType dst_dtype) { +CastFunctorType GetGpuCastFromUint16(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint16); return nullptr; } @@ -36,8 +34,7 @@ GetGpuCastFromUint16(DataType dst_dtype) { #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromUint16(DataType dst_dtype) { +CastFunctorType GetSyclCastFromUint16(DataType dst_dtype) { CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint16); return nullptr; } diff --git 
a/tensorflow/core/kernels/cast_op_impl_uint32.cc b/tensorflow/core/kernels/cast_op_impl_uint32.cc index d1a854d98b..86f5961bcc 100644 --- a/tensorflow/core/kernels/cast_op_impl_uint32.cc +++ b/tensorflow/core/kernels/cast_op_impl_uint32.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromUint32(DataType dst_dtype) { +CastFunctorType GetCpuCastFromUint32(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, uint32); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromUint32(DataType dst_dtype) { +CastFunctorType GetGpuCastFromUint32(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint32); return nullptr; } @@ -36,8 +34,7 @@ GetGpuCastFromUint32(DataType dst_dtype) { #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromUint32(DataType dst_dtype) { +CastFunctorType GetSyclCastFromUint32(DataType dst_dtype) { CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint32); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_uint64.cc b/tensorflow/core/kernels/cast_op_impl_uint64.cc index 604e0424fc..6478c266ee 100644 --- a/tensorflow/core/kernels/cast_op_impl_uint64.cc +++ b/tensorflow/core/kernels/cast_op_impl_uint64.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromUint64(DataType dst_dtype) { +CastFunctorType GetCpuCastFromUint64(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, uint64); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromUint64(DataType dst_dtype) { +CastFunctorType GetGpuCastFromUint64(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint64); return nullptr; } @@ -36,8 +34,7 @@ GetGpuCastFromUint64(DataType dst_dtype) { #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromUint64(DataType dst_dtype) { +CastFunctorType GetSyclCastFromUint64(DataType dst_dtype) { CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint64); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_uint8.cc b/tensorflow/core/kernels/cast_op_impl_uint8.cc index 2d1a6f3a4e..b22547a23e 100644 --- a/tensorflow/core/kernels/cast_op_impl_uint8.cc +++ b/tensorflow/core/kernels/cast_op_impl_uint8.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromUint8(DataType dst_dtype) { +CastFunctorType GetCpuCastFromUint8(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, uint8); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromUint8(DataType dst_dtype) { +CastFunctorType GetGpuCastFromUint8(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint8); return nullptr; } @@ -36,8 +34,7 @@ GetGpuCastFromUint8(DataType dst_dtype) { #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromUint8(DataType dst_dtype) { +CastFunctorType GetSyclCastFromUint8(DataType 
dst_dtype) { CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint8); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_test.cc b/tensorflow/core/kernels/cast_op_test.cc index 9bbf7afb16..cb305de5e3 100644 --- a/tensorflow/core/kernels/cast_op_test.cc +++ b/tensorflow/core/kernels/cast_op_test.cc @@ -40,17 +40,27 @@ static Graph* Cast(int num) { class CastOpTest : public OpsTestBase { protected: - void MakeOp(DataType src, DataType dst) { - TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast") - .Input(FakeInput(src)) - .Attr("SrcT", src) - .Attr("DstT", dst) - .Finalize(node_def())); + void MakeOp(DataType src, DataType dst, bool trunc = false) { + if (trunc) { + TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast") + .Input(FakeInput(src)) + .Attr("SrcT", src) + .Attr("DstT", dst) + .Attr("Truncate", true) + .Finalize(node_def())); + } else { + TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast") + .Input(FakeInput(src)) + .Attr("SrcT", src) + .Attr("DstT", dst) + .Finalize(node_def())); + } + TF_EXPECT_OK(InitOp()); } template <typename INPUT, typename OUTPUT> - void CheckCast() { + void CheckCast(bool trunc = false) { DataType in_type = DataTypeToEnum<INPUT>::v(); DataType out_type = DataTypeToEnum<OUTPUT>::v(); MakeOp(in_type, out_type); @@ -64,8 +74,9 @@ class CastOpTest : public OpsTestBase { } }; -#define TEST_CAST(in, out) \ - TEST_F(CastOpTest, TestCast##_##in##_##out) { CheckCast<in, out>(); } +#define TEST_CAST(in, out) \ + TEST_F(CastOpTest, TestCast##_##in##_##out) { CheckCast<in, out>(); } \ + TEST_F(CastOpTest, TestCast2##_##in##_##out) { CheckCast<in, out>(true); } #define TEST_ALL_CASTS_FROM(in) \ TEST_CAST(in, uint8); \ diff --git a/tensorflow/core/lib/bfloat16/bfloat16.h b/tensorflow/core/lib/bfloat16/bfloat16.h index 1c130ba300..d6f3f26cd5 100644 --- a/tensorflow/core/lib/bfloat16/bfloat16.h +++ b/tensorflow/core/lib/bfloat16/bfloat16.h @@ -45,17 +45,25 @@ typedef std::complex<double> complex128; struct bfloat16 { B16_DEVICE_FUNC bfloat16() {} - B16_DEVICE_FUNC explicit bfloat16(const float v) { + B16_DEVICE_FUNC static bfloat16 truncate_to_bfloat16(const float v) { + bfloat16 output; if (float_isnan(v)) { - value = NAN_VALUE; - return; + output.value = NAN_VALUE; + return output; } const uint16_t* p = reinterpret_cast<const uint16_t*>(&v); #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - value = p[0]; + output.value = p[0]; #else - value = p[1]; + output.value = p[1]; #endif + return output; + } + + B16_DEVICE_FUNC explicit bfloat16(const float v) { + // TODO(asabne) : change the below line to + // value = round_to_bfloat16(v).value; + value = truncate_to_bfloat16(v).value; } B16_DEVICE_FUNC explicit bfloat16(const double val) @@ -169,8 +177,6 @@ struct bfloat16 { // Converts a float point to bfloat16, with round-nearest-to-even as rounding // method. - // TODO(b/69266521): Add a truncate_to_bfloat16 function and make this - // function as default behavior. // TODO: There is a slightly faster implementation (8% faster on CPU) // than this (documented in cl/175987786), that is exponentially harder to // understand and document. 
Switch to the faster version when converting to diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 386ae9635a..77697756c4 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -114,6 +114,7 @@ REGISTER_OP("Cast") .Output("y: DstT") .Attr("SrcT: type") .Attr("DstT: type") + .Attr("Truncate: bool = false") .SetShapeFn(shape_inference::UnchangedShape); REGISTER_OP("_HostCast") diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index cefd5b1206..15d2ccf9d2 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -154,6 +154,7 @@ TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle, if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR TFE_OpSetAttrType(op, "SrcT", src_type_enum); TFE_OpSetAttrType(op, "DstT", dst_type_enum); + TFE_OpSetAttrBool(op, "Truncate", false); TFE_TensorHandle* output = nullptr; int num_outputs = 1; TFE_Execute(op, &output, &num_outputs, out_status); |
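For reference, the LSBZeroSetter/LSBZeroSetterHelper machinery added in cast_op.h implements truncating casts from the wider source types by zeroing the mantissa bits the destination type cannot hold (MantissaWidth&lt;I&gt;() - MantissaWidth&lt;O&gt;() of them) before handing the value to Eigen's ordinary round-to-nearest cast. The self-contained sketch below reproduces that effect for double -> float (53 - 24 = 29 dropped bits); the helper name is illustrative and not taken from the patch.

// Illustrative only: the masking trick behind LSBZeroSetter, shown for
// double -> float. Clearing the 29 mantissa bits a float cannot represent
// turns the subsequent narrowing conversion into a truncation.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

static double ZeroLow29Bits(double d) {
  if (std::isnan(d)) return d;  // NaNs are left to the ordinary cast path
  uint64_t bits;
  std::memcpy(&bits, &d, sizeof(bits));
  bits &= ~((uint64_t{1} << 29) - 1);  // 29 = mantissa(double) - mantissa(float)
  std::memcpy(&d, &bits, sizeof(d));
  return d;
}

int main() {
  double d = 1.0 + 1.5 / 16777216.0;  // 1 + 1.5 * 2^-24, between two floats
  float rounded = static_cast<float>(d);                   // rounds up to 1 + 2^-23
  float truncated = static_cast<float>(ZeroLow29Bits(d));  // truncates down to 1.0
  std::printf("rounded:   %.9g\ntruncated: %.9g\n", rounded, truncated);
}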