aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-25 08:23:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-25 08:27:36 -0700
commitb3771feab49e2122164737a860341727d08c2d8c (patch)
tree5fb440041db26ef96eb14e7491cb67fe06e7c3d4
parentbe3d22844025e42e177a21479f3ae73bc5351c1f (diff)
This change started with an intention of adding an attribute to cast ops to decide
whether bfloat16 casts should use truncation or rounding. This is a preparatory change before we switch the default float ==> bfloat16 cast to use rounding instead of truncation. The attribute added can then be specified on casts that rely on the truncation, e.g., the TensorFlow send/receive operations. It later emerged that the choice of doing truncation is useful more generally. Therefore, this change allows the new attribute to be used by all relevant casts to use truncation instead of rounding. PiperOrigin-RevId: 205996367
-rw-r--r--tensorflow/core/framework/bfloat16_test.cc3
-rw-r--r--tensorflow/core/graph/graph_partition.cc8
-rw-r--r--tensorflow/core/kernels/cast_op.cc4
-rw-r--r--tensorflow/core/kernels/cast_op.h164
-rw-r--r--tensorflow/core/kernels/cast_op_gpu.cu.cc48
-rw-r--r--tensorflow/core/kernels/cast_op_impl.h152
-rw-r--r--tensorflow/core/kernels/cast_op_impl_bfloat.cc11
-rw-r--r--tensorflow/core/kernels/cast_op_impl_bool.cc9
-rw-r--r--tensorflow/core/kernels/cast_op_impl_complex128.cc6
-rw-r--r--tensorflow/core/kernels/cast_op_impl_complex64.cc6
-rw-r--r--tensorflow/core/kernels/cast_op_impl_double.cc9
-rw-r--r--tensorflow/core/kernels/cast_op_impl_float.cc9
-rw-r--r--tensorflow/core/kernels/cast_op_impl_half.cc6
-rw-r--r--tensorflow/core/kernels/cast_op_impl_int16.cc9
-rw-r--r--tensorflow/core/kernels/cast_op_impl_int32.cc9
-rw-r--r--tensorflow/core/kernels/cast_op_impl_int64.cc9
-rw-r--r--tensorflow/core/kernels/cast_op_impl_int8.cc9
-rw-r--r--tensorflow/core/kernels/cast_op_impl_uint16.cc9
-rw-r--r--tensorflow/core/kernels/cast_op_impl_uint32.cc9
-rw-r--r--tensorflow/core/kernels/cast_op_impl_uint64.cc9
-rw-r--r--tensorflow/core/kernels/cast_op_impl_uint8.cc9
-rw-r--r--tensorflow/core/kernels/cast_op_test.cc29
-rw-r--r--tensorflow/core/lib/bfloat16/bfloat16.h20
-rw-r--r--tensorflow/core/ops/math_ops.cc1
-rw-r--r--tensorflow/python/eager/pywrap_tensor.cc1
25 files changed, 342 insertions, 216 deletions
diff --git a/tensorflow/core/framework/bfloat16_test.cc b/tensorflow/core/framework/bfloat16_test.cc
index 206396a25a..0a1b5e1975 100644
--- a/tensorflow/core/framework/bfloat16_test.cc
+++ b/tensorflow/core/framework/bfloat16_test.cc
@@ -45,7 +45,8 @@ class Bfloat16Test : public ::testing::Test,
public ::testing::WithParamInterface<Bfloat16TestParam> {};
TEST_P(Bfloat16Test, TruncateTest) {
- bfloat16 truncated(GetParam().input);
+ bfloat16 truncated = bfloat16::truncate_to_bfloat16((GetParam().input));
+
if (std::isnan(GetParam().input)) {
EXPECT_TRUE(std::isnan(float(truncated)) || std::isinf(float(truncated)));
return;
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
index 1b1941f9c1..ea0a814ab8 100644
--- a/tensorflow/core/graph/graph_partition.cc
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -214,6 +214,14 @@ NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
cast_builder.Attr("_start_time", start_time);
}
cast_builder.Attr("DstT", cast_dtype);
+
+ if (cast_dtype == DT_BFLOAT16) {
+ // the below attribute specifies that the cast to bfloat16 should use
+ // truncation. This is needed to retain legacy behavior when we change
+ // the default bfloat16 casts to use rounding instead of truncation
+ cast_builder.Attr("Truncate", true);
+ }
+
NodeDef* cast = gdef->add_node();
*status = cast_builder.Finalize(cast);
if (!status->ok()) return nullptr;
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc
index b4c97df38b..0478c93280 100644
--- a/tensorflow/core/kernels/cast_op.cc
+++ b/tensorflow/core/kernels/cast_op.cc
@@ -59,6 +59,8 @@ CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &external_dst_dtype_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Truncate", &use_truncation_));
+
// Quantized data types use the same underlying format as their non quantized
// version so we use the non quantized implementation for casting.
if (external_dst_dtype_ == DT_QUINT8) {
@@ -100,7 +102,7 @@ void CastOpBase::Compute(OpKernelContext* ctx) {
Tensor* out = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
out->set_dtype(dst_dtype_);
- work_(ctx, in, out);
+ work_(ctx, in, out, use_truncation_);
out->set_dtype(external_dst_dtype_);
}
}
diff --git a/tensorflow/core/kernels/cast_op.h b/tensorflow/core/kernels/cast_op.h
index aae1e7ff19..527ab528c9 100644
--- a/tensorflow/core/kernels/cast_op.h
+++ b/tensorflow/core/kernels/cast_op.h
@@ -24,8 +24,71 @@ limitations under the License.
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/types.h"
+// Note that the GPU cast functor templates need to be instantiated unlike the
+// CPU ones, and hence their specializations are different than that for CPUs.
+#ifdef SPECIALIZE_FOR_GPUS
+#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT) \
+ template <typename Device> \
+ struct CastFunctor<Device, OUT_TYPE, IN_OUT> { \
+ void operator()(const Device& d, \
+ typename TTypes<OUT_TYPE>::Flat out_tensor, \
+ typename TTypes<IN_OUT>::ConstFlat in_tensor, \
+ bool truncate = false) { \
+ if (truncate) { \
+ out_tensor.device(d) = \
+ in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>()) \
+ .template cast<OUT_TYPE>(); \
+ } else { \
+ out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \
+ } \
+ } \
+ }; \
+ template struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT>;
+#else
+#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT) \
+ template <> \
+ struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT> { \
+ void operator()(const DEVICE& d, \
+ typename TTypes<OUT_TYPE>::Flat out_tensor, \
+ typename TTypes<IN_OUT>::ConstFlat in_tensor, \
+ bool truncate = false) { \
+ if (truncate) { \
+ out_tensor.device(d) = \
+ in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>()) \
+ .template cast<OUT_TYPE>(); \
+ } else { \
+ out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \
+ } \
+ } \
+ };
+#endif
+
+#define CAST_FUNCTORS(devname) \
+ SPECIALIZE_CAST(devname, float, double) \
+ SPECIALIZE_CAST(devname, float, std::complex<double>) \
+ SPECIALIZE_CAST(devname, std::complex<float>, std::complex<double>) \
+ SPECIALIZE_CAST(devname, std::complex<float>, double) \
+ SPECIALIZE_CAST(devname, Eigen::half, double) \
+ SPECIALIZE_CAST(devname, Eigen::half, float) \
+ SPECIALIZE_CAST(devname, Eigen::half, std::complex<double>) \
+ SPECIALIZE_CAST(devname, Eigen::half, std::complex<float>) \
+ SPECIALIZE_CAST(devname, bfloat16, float) \
+ template <typename OUT_TYPE, typename IN_OUT> \
+ struct CastFunctor<devname, OUT_TYPE, IN_OUT> { \
+ void operator()(const devname& d, \
+ typename TTypes<OUT_TYPE>::Flat out_tensor, \
+ typename TTypes<IN_OUT>::ConstFlat in_tensor, \
+ bool truncate = false) { \
+ out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \
+ } \
+ };
+
namespace tensorflow {
+typedef std::function<void(OpKernelContext*, const Tensor&, Tensor*,
+ bool trunc)>
+ CastFunctorType;
+
// Common base class of Cast kernels
class CastOpBase : public OpKernel {
public:
@@ -38,8 +101,8 @@ class CastOpBase : public OpKernel {
DataType dst_dtype_;
DataType external_src_dtype_;
DataType external_dst_dtype_;
- std::function<void(OpKernelContext*, const Tensor&, Tensor*)> work_ = nullptr;
-
+ bool use_truncation_;
+ CastFunctorType work_ = nullptr;
Status Unimplemented();
TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase);
@@ -56,6 +119,23 @@ class CpuCastOp : public CastOpBase {
namespace functor {
+template <typename I>
+constexpr int MantissaWidth() {
+ return std::numeric_limits<I>::digits;
+}
+
+template <>
+constexpr int MantissaWidth<Eigen::half>() {
+ // Remember, there's 1 hidden bit
+ return 10 + 1;
+}
+
+template <>
+constexpr int MantissaWidth<bfloat16>() {
+ // Remember, there's 1 hidden bit
+ return 7 + 1;
+}
+
template <typename Device, typename Tout, typename Tin>
void Cast(const Device& d, typename TTypes<Tout>::Flat o,
typename TTypes<Tin>::ConstFlat i) {
@@ -65,7 +145,85 @@ void Cast(const Device& d, typename TTypes<Tout>::Flat o,
template <typename Device, typename Tout, typename Tin>
struct CastFunctor {
void operator()(const Device& d, typename TTypes<Tout>::Flat o,
- typename TTypes<Tin>::ConstFlat i);
+ typename TTypes<Tin>::ConstFlat i, bool truncate = false);
+};
+
+// Only enable LSBZeroSetterHelper for 64 and 32 bit input data types.
+// Specialize for others if needed in future.
+template <typename I>
+typename std::enable_if<sizeof(I) == 8, void>::type EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) {
+ // Only zero the bits for non-NaNs.
+ // For NaNs, let the non-truncation version handle it.
+ if (!std::isnan(t)) {
+ uint64_t* p = reinterpret_cast<uint64_t*>(&t);
+ *p &= (0xFFFFFFFFFFFFFFFF << n);
+ }
+}
+
+template <typename I>
+typename std::enable_if<sizeof(I) == 4, void>::type EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) {
+ // Only zero the bits for non-NaNs.
+ // For NaNs, let the non-truncation version handle it.
+ if (!std::isnan(t)) {
+ uint32_t* p = reinterpret_cast<uint32_t*>(&t);
+ *p &= (0xFFFFFFFF << n);
+ }
+}
+
+// Set n least significant bits to 0
+template <typename I, typename O>
+struct LSBZeroSetter {
+ EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter)
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const I operator()(const I& a) const {
+ constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
+ static_assert(
+ bits > 0,
+ "The output type must have fewer mantissa bits than the input type\n");
+ I t = a;
+ LSBZeroSetterHelper(t, bits);
+ return t;
+ }
+};
+
+template <typename I, typename O>
+struct LSBZeroSetter<std::complex<I>, std::complex<O>> {
+ EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter)
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()(
+ const std::complex<I>& a) const {
+ constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
+ static_assert(
+ bits > 0,
+ "The output type must have fewer mantissa bits than the input type\n");
+ I re = std::real(a);
+ I img = std::imag(a);
+ LSBZeroSetterHelper(re, bits);
+ LSBZeroSetterHelper(img, bits);
+ std::complex<I> toReturn(re, img);
+ return toReturn;
+ }
+};
+
+template <typename I, typename O>
+struct LSBZeroSetter<std::complex<I>, O> {
+ EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter)
+ // Sets the 16 LSBits of the float to 0
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()(
+ const std::complex<I>& a) const {
+ constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
+ static_assert(
+ bits > 0,
+ "The output type must have fewer mantissa bits than the input type\n");
+ I re = std::real(a);
+ I img = std::imag(a);
+ LSBZeroSetterHelper(re, bits);
+ LSBZeroSetterHelper(img, bits);
+ std::complex<I> toReturn(re, img);
+ return toReturn;
+ }
};
} // end namespace functor
diff --git a/tensorflow/core/kernels/cast_op_gpu.cu.cc b/tensorflow/core/kernels/cast_op_gpu.cu.cc
index 607e7f5efd..036996fca2 100644
--- a/tensorflow/core/kernels/cast_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/cast_op_gpu.cu.cc
@@ -18,22 +18,19 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/bfloat16.h"
+#define SPECIALIZE_FOR_GPUS
#include "tensorflow/core/kernels/cast_op.h"
+#undef SPECIALIZE_FOR_GPUS
namespace tensorflow {
namespace functor {
typedef Eigen::GpuDevice GPUDevice;
-template <typename O, typename I>
-struct CastFunctor<GPUDevice, O, I> {
- void operator()(const GPUDevice& d, typename TTypes<O>::Flat o,
- typename TTypes<I>::ConstFlat i) {
- Cast<GPUDevice, O, I>(d, o, i);
- }
-};
+CAST_FUNCTORS(GPUDevice);
#define DEFINE(O, I) template struct CastFunctor<GPUDevice, O, I>
+
#define DEFINE_ALL_FROM(in_type) \
DEFINE(in_type, bool); \
DEFINE(in_type, uint8); \
@@ -59,14 +56,43 @@ DEFINE_ALL_FROM(int8);
DEFINE_ALL_FROM(int16);
DEFINE_ALL_FROM(int32);
DEFINE_ALL_FROM(int64);
-DEFINE_ALL_FROM(Eigen::half);
-DEFINE_ALL_FROM(float);
DEFINE_ALL_FROM(double);
-DEFINE_ALL_FROM(std::complex<float>);
DEFINE_ALL_FROM(std::complex<double>);
-DEFINE(bfloat16, float);
DEFINE(float, bfloat16);
+#define DEFINE_ALL_TO_FLOAT(out_type) \
+ DEFINE(out_type, bool); \
+ DEFINE(out_type, uint8); \
+ DEFINE(out_type, uint16); \
+ DEFINE(out_type, uint32); \
+ DEFINE(out_type, uint64); \
+ DEFINE(out_type, int8); \
+ DEFINE(out_type, int16); \
+ DEFINE(out_type, int32); \
+ DEFINE(out_type, int64); \
+ DEFINE(out_type, Eigen::half); \
+ DEFINE(out_type, float); \
+ DEFINE(out_type, std::complex<float>)
+
+#define DEFINE_ALL_TO_HALF(out_type) \
+ DEFINE(out_type, bool); \
+ DEFINE(out_type, uint8); \
+ DEFINE(out_type, uint16); \
+ DEFINE(out_type, uint32); \
+ DEFINE(out_type, uint64); \
+ DEFINE(out_type, int8); \
+ DEFINE(out_type, int16); \
+ DEFINE(out_type, int32); \
+ DEFINE(out_type, int64); \
+ DEFINE(out_type, Eigen::half)
+
+DEFINE_ALL_TO_HALF(Eigen::half);
+DEFINE_ALL_TO_HALF(bfloat16);
+DEFINE_ALL_TO_FLOAT(float);
+DEFINE_ALL_TO_FLOAT(std::complex<float>);
+
+#undef DEFINE_ALL_TO_FLOAT
+#undef DEFINE_ALL_TO_HALF
#undef DEFINE_ALL_FROM
#undef DEFINE
diff --git a/tensorflow/core/kernels/cast_op_impl.h b/tensorflow/core/kernels/cast_op_impl.h
index fe821b25df..b899bac681 100644
--- a/tensorflow/core/kernels/cast_op_impl.h
+++ b/tensorflow/core/kernels/cast_op_impl.h
@@ -25,22 +25,10 @@ namespace tensorflow {
namespace functor {
-template <typename O, typename I>
-struct CastFunctor<Eigen::ThreadPoolDevice, O, I> {
- void operator()(const Eigen::ThreadPoolDevice& d, typename TTypes<O>::Flat o,
- typename TTypes<I>::ConstFlat i) {
- o.device(d) = i.template cast<O>();
- }
-};
+CAST_FUNCTORS(Eigen::ThreadPoolDevice);
#ifdef TENSORFLOW_USE_SYCL
-template <typename O, typename I>
-struct CastFunctor<Eigen::SyclDevice, O, I> {
- void operator()(const Eigen::SyclDevice& d, typename TTypes<O>::Flat o,
- typename TTypes<I>::ConstFlat i) {
- o.device(d) = i.template cast<O>();
- }
-};
+CAST_FUNCTORS(Eigen::SyclDevice);
#endif // TENSORFLOW_USE_SYCL
} // namespace functor
@@ -68,139 +56,103 @@ struct CastFunctor<Eigen::SyclDevice, O, I> {
CURRY_TYPES3_NO_BF16(FN, arg0, arg1) \
FN(arg0, arg1, bfloat16);
-#define CAST_CASE(DEVICE, IN, OUT) \
- if (DataTypeToEnum<OUT>::value == dst_dtype) { \
- return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) { \
- functor::CastFunctor<DEVICE, OUT, IN> func; \
- func(ctx->eigen_device<DEVICE>(), out->flat<OUT>(), inp.flat<IN>()); \
- }; \
+#define CAST_CASE(DEVICE, IN, OUT) \
+ if (DataTypeToEnum<OUT>::value == dst_dtype) { \
+ return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out, \
+ bool truncate) { \
+ functor::CastFunctor<DEVICE, OUT, IN> func; \
+ func(ctx->eigen_device<DEVICE>(), out->flat<OUT>(), inp.flat<IN>(), \
+ truncate); \
+ }; \
}
// The functions below are implemented in the cast_op_impl_*.cc files.
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromBool(DataType dst_dtype);
+CastFunctorType GetCpuCastFromBool(DataType dst_dtype);
+
+CastFunctorType GetCpuCastFromUint8(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint8(DataType dst_dtype);
+CastFunctorType GetCpuCastFromUint16(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint16(DataType dst_dtype);
+CastFunctorType GetCpuCastFromInt8(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint32(DataType dst_dtype);
+CastFunctorType GetCpuCastFromUint32(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint64(DataType dst_dtype);
+CastFunctorType GetCpuCastFromUint64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt8(DataType dst_dtype);
+CastFunctorType GetCpuCastFromInt8(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt16(DataType dst_dtype);
+CastFunctorType GetCpuCastFromInt16(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt32(DataType dst_dtype);
+CastFunctorType GetCpuCastFromInt32(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt64(DataType dst_dtype);
+CastFunctorType GetCpuCastFromInt64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromHalf(DataType dst_dtype);
+CastFunctorType GetCpuCastFromHalf(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromFloat(DataType dst_dtype);
+CastFunctorType GetCpuCastFromFloat(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromDouble(DataType dst_dtype);
+CastFunctorType GetCpuCastFromDouble(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromComplex64(DataType dst_dtype);
+CastFunctorType GetCpuCastFromComplex64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromComplex128(DataType dst_dtype);
+CastFunctorType GetCpuCastFromComplex128(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromBfloat(DataType dst_dtype);
+CastFunctorType GetCpuCastFromBfloat(DataType dst_dtype);
#if GOOGLE_CUDA
// Same, for GPU.
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromBool(DataType dst_dtype);
+CastFunctorType GetGpuCastFromBool(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint8(DataType dst_dtype);
+CastFunctorType GetGpuCastFromUint8(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint16(DataType dst_dtype);
+CastFunctorType GetGpuCastFromUint16(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint32(DataType dst_dtype);
+CastFunctorType GetGpuCastFromInt8(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint64(DataType dst_dtype);
+CastFunctorType GetGpuCastFromUint32(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt8(DataType dst_dtype);
+CastFunctorType GetGpuCastFromUint64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt16(DataType dst_dtype);
+CastFunctorType GetGpuCastFromInt16(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt32(DataType dst_dtype);
+CastFunctorType GetGpuCastFromInt32(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt64(DataType dst_dtype);
+CastFunctorType GetGpuCastFromInt64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromHalf(DataType dst_dtype);
+CastFunctorType GetGpuCastFromHalf(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromFloat(DataType dst_dtype);
+CastFunctorType GetGpuCastFromFloat(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromDouble(DataType dst_dtype);
+CastFunctorType GetGpuCastFromDouble(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromComplex64(DataType dst_dtype);
+CastFunctorType GetGpuCastFromComplex64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromComplex128(DataType dst_dtype);
+CastFunctorType GetGpuCastFromComplex128(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromBfloat(DataType dst_dtype);
+CastFunctorType GetGpuCastFromBfloat(DataType dst_dtype);
#endif // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromBool(DataType dst_dtype);
+CastFunctorType GetSyclCastFromBool(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint8(DataType dst_dtype);
+CastFunctorType GetSyclCastFromUint8(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint16(DataType dst_dtype);
+CastFunctorType GetSyclCastFromUint16(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint32(DataType dst_dtype);
+CastFunctorType GetSyclCastFromUint32(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint64(DataType dst_dtype);
+CastFunctorType GetSyclCastFromUint64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromInt16(DataType dst_dtype);
+CastFunctorType GetSyclCastFromInt16(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromInt32(DataType dst_dtype);
+CastFunctorType GetSyclCastFromInt32(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromInt64(DataType dst_dtype);
+CastFunctorType GetSyclCastFromInt64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromFloat(DataType dst_dtype);
+CastFunctorType GetSyclCastFromFloat(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromDouble(DataType dst_dtype);
+CastFunctorType GetSyclCastFromDouble(DataType dst_dtype);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cast_op_impl_bfloat.cc b/tensorflow/core/kernels/cast_op_impl_bfloat.cc
index bfa7ba0d47..96aae15608 100644
--- a/tensorflow/core/kernels/cast_op_impl_bfloat.cc
+++ b/tensorflow/core/kernels/cast_op_impl_bfloat.cc
@@ -22,20 +22,19 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromBfloat(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromBfloat(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, bfloat16);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromBfloat(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromBfloat(DataType dst_dtype) {
if (dst_dtype == DT_FLOAT) {
- return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) {
+ return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out,
+ bool truncate) {
functor::CastFunctor<GPUDevice, float, bfloat16> func;
func(ctx->eigen_device<GPUDevice>(), out->flat<float>(),
- inp.flat<bfloat16>());
+ inp.flat<bfloat16>(), truncate);
};
}
return nullptr;
diff --git a/tensorflow/core/kernels/cast_op_impl_bool.cc b/tensorflow/core/kernels/cast_op_impl_bool.cc
index c5c7394b43..792d4781f2 100644
--- a/tensorflow/core/kernels/cast_op_impl_bool.cc
+++ b/tensorflow/core/kernels/cast_op_impl_bool.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromBool(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromBool(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, bool);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromBool(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromBool(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, bool);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromBool(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromBool(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromBool(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, bool);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_complex128.cc b/tensorflow/core/kernels/cast_op_impl_complex128.cc
index 52899d58cd..9a184e5954 100644
--- a/tensorflow/core/kernels/cast_op_impl_complex128.cc
+++ b/tensorflow/core/kernels/cast_op_impl_complex128.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromComplex128(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromComplex128(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, std::complex<double>);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromComplex128(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromComplex128(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, std::complex<double>);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_complex64.cc b/tensorflow/core/kernels/cast_op_impl_complex64.cc
index 617bda53d5..77bc620b46 100644
--- a/tensorflow/core/kernels/cast_op_impl_complex64.cc
+++ b/tensorflow/core/kernels/cast_op_impl_complex64.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromComplex64(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromComplex64(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, std::complex<float>);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromComplex64(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromComplex64(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, std::complex<float>);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_double.cc b/tensorflow/core/kernels/cast_op_impl_double.cc
index 7dc485ddad..ff9056897f 100644
--- a/tensorflow/core/kernels/cast_op_impl_double.cc
+++ b/tensorflow/core/kernels/cast_op_impl_double.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromDouble(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromDouble(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, double);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromDouble(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromDouble(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, double);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromDouble(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromDouble(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromDouble(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, double);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_float.cc b/tensorflow/core/kernels/cast_op_impl_float.cc
index 1c933914fd..f1e8f0e37b 100644
--- a/tensorflow/core/kernels/cast_op_impl_float.cc
+++ b/tensorflow/core/kernels/cast_op_impl_float.cc
@@ -22,15 +22,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromFloat(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromFloat(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, float);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromFloat(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromFloat(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, GPUDevice, float);
return nullptr;
}
@@ -38,8 +36,7 @@ GetGpuCastFromFloat(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromFloat(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromFloat(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, float);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_half.cc b/tensorflow/core/kernels/cast_op_impl_half.cc
index ef4b94e326..5da3a01352 100644
--- a/tensorflow/core/kernels/cast_op_impl_half.cc
+++ b/tensorflow/core/kernels/cast_op_impl_half.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromHalf(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromHalf(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, Eigen::half);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromHalf(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromHalf(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, Eigen::half);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_int16.cc b/tensorflow/core/kernels/cast_op_impl_int16.cc
index 59360f7445..440ee88fb5 100644
--- a/tensorflow/core/kernels/cast_op_impl_int16.cc
+++ b/tensorflow/core/kernels/cast_op_impl_int16.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt16(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromInt16(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, int16);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt16(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromInt16(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, int16);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromInt16(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromInt16(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromInt16(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int16);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_int32.cc b/tensorflow/core/kernels/cast_op_impl_int32.cc
index a867392fde..4b3e7efddc 100644
--- a/tensorflow/core/kernels/cast_op_impl_int32.cc
+++ b/tensorflow/core/kernels/cast_op_impl_int32.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt32(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromInt32(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, int32);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt32(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromInt32(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, int32);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromInt32(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromInt32(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromInt32(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int32);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_int64.cc b/tensorflow/core/kernels/cast_op_impl_int64.cc
index 467a8f6c89..0f711aa560 100644
--- a/tensorflow/core/kernels/cast_op_impl_int64.cc
+++ b/tensorflow/core/kernels/cast_op_impl_int64.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt64(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromInt64(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, int64);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt64(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromInt64(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, int64);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromInt64(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromInt64(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromInt64(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int64);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_int8.cc b/tensorflow/core/kernels/cast_op_impl_int8.cc
index 21002a4321..eac185d5a0 100644
--- a/tensorflow/core/kernels/cast_op_impl_int8.cc
+++ b/tensorflow/core/kernels/cast_op_impl_int8.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt8(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromInt8(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, int8);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt8(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromInt8(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, int8);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromInt8(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromInt8(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromInt8(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int8);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_uint16.cc b/tensorflow/core/kernels/cast_op_impl_uint16.cc
index cd829bae2a..3aebbdc1f3 100644
--- a/tensorflow/core/kernels/cast_op_impl_uint16.cc
+++ b/tensorflow/core/kernels/cast_op_impl_uint16.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint16(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromUint16(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, uint16);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint16(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromUint16(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint16);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromUint16(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint16(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromUint16(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint16);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_uint32.cc b/tensorflow/core/kernels/cast_op_impl_uint32.cc
index d1a854d98b..86f5961bcc 100644
--- a/tensorflow/core/kernels/cast_op_impl_uint32.cc
+++ b/tensorflow/core/kernels/cast_op_impl_uint32.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint32(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromUint32(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, uint32);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint32(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromUint32(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint32);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromUint32(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint32(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromUint32(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint32);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_uint64.cc b/tensorflow/core/kernels/cast_op_impl_uint64.cc
index 604e0424fc..6478c266ee 100644
--- a/tensorflow/core/kernels/cast_op_impl_uint64.cc
+++ b/tensorflow/core/kernels/cast_op_impl_uint64.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint64(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromUint64(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, uint64);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint64(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromUint64(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint64);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromUint64(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint64(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromUint64(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint64);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_uint8.cc b/tensorflow/core/kernels/cast_op_impl_uint8.cc
index 2d1a6f3a4e..b22547a23e 100644
--- a/tensorflow/core/kernels/cast_op_impl_uint8.cc
+++ b/tensorflow/core/kernels/cast_op_impl_uint8.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint8(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromUint8(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, uint8);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint8(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromUint8(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint8);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromUint8(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint8(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromUint8(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint8);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_test.cc b/tensorflow/core/kernels/cast_op_test.cc
index 9bbf7afb16..cb305de5e3 100644
--- a/tensorflow/core/kernels/cast_op_test.cc
+++ b/tensorflow/core/kernels/cast_op_test.cc
@@ -40,17 +40,27 @@ static Graph* Cast(int num) {
class CastOpTest : public OpsTestBase {
protected:
- void MakeOp(DataType src, DataType dst) {
- TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast")
- .Input(FakeInput(src))
- .Attr("SrcT", src)
- .Attr("DstT", dst)
- .Finalize(node_def()));
+ void MakeOp(DataType src, DataType dst, bool trunc = false) {
+ if (trunc) {
+ TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast")
+ .Input(FakeInput(src))
+ .Attr("SrcT", src)
+ .Attr("DstT", dst)
+ .Attr("Truncate", true)
+ .Finalize(node_def()));
+ } else {
+ TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast")
+ .Input(FakeInput(src))
+ .Attr("SrcT", src)
+ .Attr("DstT", dst)
+ .Finalize(node_def()));
+ }
+
TF_EXPECT_OK(InitOp());
}
template <typename INPUT, typename OUTPUT>
- void CheckCast() {
+ void CheckCast(bool trunc = false) {
DataType in_type = DataTypeToEnum<INPUT>::v();
DataType out_type = DataTypeToEnum<OUTPUT>::v();
MakeOp(in_type, out_type);
@@ -64,8 +74,9 @@ class CastOpTest : public OpsTestBase {
}
};
-#define TEST_CAST(in, out) \
- TEST_F(CastOpTest, TestCast##_##in##_##out) { CheckCast<in, out>(); }
+#define TEST_CAST(in, out) \
+ TEST_F(CastOpTest, TestCast##_##in##_##out) { CheckCast<in, out>(); } \
+ TEST_F(CastOpTest, TestCast2##_##in##_##out) { CheckCast<in, out>(true); }
#define TEST_ALL_CASTS_FROM(in) \
TEST_CAST(in, uint8); \
diff --git a/tensorflow/core/lib/bfloat16/bfloat16.h b/tensorflow/core/lib/bfloat16/bfloat16.h
index 1c130ba300..d6f3f26cd5 100644
--- a/tensorflow/core/lib/bfloat16/bfloat16.h
+++ b/tensorflow/core/lib/bfloat16/bfloat16.h
@@ -45,17 +45,25 @@ typedef std::complex<double> complex128;
struct bfloat16 {
B16_DEVICE_FUNC bfloat16() {}
- B16_DEVICE_FUNC explicit bfloat16(const float v) {
+ B16_DEVICE_FUNC static bfloat16 truncate_to_bfloat16(const float v) {
+ bfloat16 output;
if (float_isnan(v)) {
- value = NAN_VALUE;
- return;
+ output.value = NAN_VALUE;
+ return output;
}
const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
- value = p[0];
+ output.value = p[0];
#else
- value = p[1];
+ output.value = p[1];
#endif
+ return output;
+ }
+
+ B16_DEVICE_FUNC explicit bfloat16(const float v) {
+ // TODO(asabne) : change the below line to
+ // value = round_to_bfloat16(v).value;
+ value = truncate_to_bfloat16(v).value;
}
B16_DEVICE_FUNC explicit bfloat16(const double val)
@@ -169,8 +177,6 @@ struct bfloat16 {
// Converts a float point to bfloat16, with round-nearest-to-even as rounding
// method.
- // TODO(b/69266521): Add a truncate_to_bfloat16 function and make this
- // function as default behavior.
// TODO: There is a slightly faster implementation (8% faster on CPU)
// than this (documented in cl/175987786), that is exponentially harder to
// understand and document. Switch to the faster version when converting to
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 386ae9635a..77697756c4 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -114,6 +114,7 @@ REGISTER_OP("Cast")
.Output("y: DstT")
.Attr("SrcT: type")
.Attr("DstT: type")
+ .Attr("Truncate: bool = false")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("_HostCast")
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index cefd5b1206..15d2ccf9d2 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -154,6 +154,7 @@ TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
TFE_OpSetAttrType(op, "SrcT", src_type_enum);
TFE_OpSetAttrType(op, "DstT", dst_type_enum);
+ TFE_OpSetAttrBool(op, "Truncate", false);
TFE_TensorHandle* output = nullptr;
int num_outputs = 1;
TFE_Execute(op, &output, &num_outputs, out_status);