aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/core/framework/bfloat16_test.cc3
-rw-r--r--tensorflow/core/graph/graph_partition.cc8
-rw-r--r--tensorflow/core/kernels/cast_op.cc4
-rw-r--r--tensorflow/core/kernels/cast_op.h164
-rw-r--r--tensorflow/core/kernels/cast_op_gpu.cu.cc48
-rw-r--r--tensorflow/core/kernels/cast_op_impl.h152
-rw-r--r--tensorflow/core/kernels/cast_op_impl_bfloat.cc11
-rw-r--r--tensorflow/core/kernels/cast_op_impl_bool.cc9
-rw-r--r--tensorflow/core/kernels/cast_op_impl_complex128.cc6
-rw-r--r--tensorflow/core/kernels/cast_op_impl_complex64.cc6
-rw-r--r--tensorflow/core/kernels/cast_op_impl_double.cc9
-rw-r--r--tensorflow/core/kernels/cast_op_impl_float.cc9
-rw-r--r--tensorflow/core/kernels/cast_op_impl_half.cc6
-rw-r--r--tensorflow/core/kernels/cast_op_impl_int16.cc9
-rw-r--r--tensorflow/core/kernels/cast_op_impl_int32.cc9
-rw-r--r--tensorflow/core/kernels/cast_op_impl_int64.cc9
-rw-r--r--tensorflow/core/kernels/cast_op_impl_int8.cc9
-rw-r--r--tensorflow/core/kernels/cast_op_impl_uint16.cc9
-rw-r--r--tensorflow/core/kernels/cast_op_impl_uint32.cc9
-rw-r--r--tensorflow/core/kernels/cast_op_impl_uint64.cc9
-rw-r--r--tensorflow/core/kernels/cast_op_impl_uint8.cc9
-rw-r--r--tensorflow/core/kernels/cast_op_test.cc29
-rw-r--r--tensorflow/core/lib/bfloat16/bfloat16.h20
-rw-r--r--tensorflow/core/ops/math_ops.cc1
-rw-r--r--tensorflow/python/eager/pywrap_tensor.cc1
25 files changed, 342 insertions, 216 deletions
diff --git a/tensorflow/core/framework/bfloat16_test.cc b/tensorflow/core/framework/bfloat16_test.cc
index 206396a25a..0a1b5e1975 100644
--- a/tensorflow/core/framework/bfloat16_test.cc
+++ b/tensorflow/core/framework/bfloat16_test.cc
@@ -45,7 +45,8 @@ class Bfloat16Test : public ::testing::Test,
public ::testing::WithParamInterface<Bfloat16TestParam> {};
TEST_P(Bfloat16Test, TruncateTest) {
- bfloat16 truncated(GetParam().input);
+ bfloat16 truncated = bfloat16::truncate_to_bfloat16((GetParam().input));
+
if (std::isnan(GetParam().input)) {
EXPECT_TRUE(std::isnan(float(truncated)) || std::isinf(float(truncated)));
return;
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
index 1b1941f9c1..ea0a814ab8 100644
--- a/tensorflow/core/graph/graph_partition.cc
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -214,6 +214,14 @@ NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
cast_builder.Attr("_start_time", start_time);
}
cast_builder.Attr("DstT", cast_dtype);
+
+ if (cast_dtype == DT_BFLOAT16) {
+ // The attribute below specifies that the cast to bfloat16 should use
+ // truncation. This is needed to retain legacy behavior when we change
+ // the default bfloat16 casts to use rounding instead of truncation.
+ cast_builder.Attr("Truncate", true);
+ }
+
NodeDef* cast = gdef->add_node();
*status = cast_builder.Finalize(cast);
if (!status->ok()) return nullptr;
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc
index b4c97df38b..0478c93280 100644
--- a/tensorflow/core/kernels/cast_op.cc
+++ b/tensorflow/core/kernels/cast_op.cc
@@ -59,6 +59,8 @@ CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &external_dst_dtype_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Truncate", &use_truncation_));
+
// Quantized data types use the same underlying format as their non quantized
// version so we use the non quantized implementation for casting.
if (external_dst_dtype_ == DT_QUINT8) {
@@ -100,7 +102,7 @@ void CastOpBase::Compute(OpKernelContext* ctx) {
Tensor* out = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
out->set_dtype(dst_dtype_);
- work_(ctx, in, out);
+ work_(ctx, in, out, use_truncation_);
out->set_dtype(external_dst_dtype_);
}
}
diff --git a/tensorflow/core/kernels/cast_op.h b/tensorflow/core/kernels/cast_op.h
index aae1e7ff19..527ab528c9 100644
--- a/tensorflow/core/kernels/cast_op.h
+++ b/tensorflow/core/kernels/cast_op.h
@@ -24,8 +24,71 @@ limitations under the License.
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/types.h"
+// Note that the GPU cast functor templates need to be explicitly instantiated,
+// unlike the CPU ones, and hence their specializations differ from the CPU ones.
+#ifdef SPECIALIZE_FOR_GPUS
+#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT) \
+ template <typename Device> \
+ struct CastFunctor<Device, OUT_TYPE, IN_OUT> { \
+ void operator()(const Device& d, \
+ typename TTypes<OUT_TYPE>::Flat out_tensor, \
+ typename TTypes<IN_OUT>::ConstFlat in_tensor, \
+ bool truncate = false) { \
+ if (truncate) { \
+ out_tensor.device(d) = \
+ in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>()) \
+ .template cast<OUT_TYPE>(); \
+ } else { \
+ out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \
+ } \
+ } \
+ }; \
+ template struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT>;
+#else
+#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT) \
+ template <> \
+ struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT> { \
+ void operator()(const DEVICE& d, \
+ typename TTypes<OUT_TYPE>::Flat out_tensor, \
+ typename TTypes<IN_OUT>::ConstFlat in_tensor, \
+ bool truncate = false) { \
+ if (truncate) { \
+ out_tensor.device(d) = \
+ in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>()) \
+ .template cast<OUT_TYPE>(); \
+ } else { \
+ out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \
+ } \
+ } \
+ };
+#endif
+
+#define CAST_FUNCTORS(devname) \
+ SPECIALIZE_CAST(devname, float, double) \
+ SPECIALIZE_CAST(devname, float, std::complex<double>) \
+ SPECIALIZE_CAST(devname, std::complex<float>, std::complex<double>) \
+ SPECIALIZE_CAST(devname, std::complex<float>, double) \
+ SPECIALIZE_CAST(devname, Eigen::half, double) \
+ SPECIALIZE_CAST(devname, Eigen::half, float) \
+ SPECIALIZE_CAST(devname, Eigen::half, std::complex<double>) \
+ SPECIALIZE_CAST(devname, Eigen::half, std::complex<float>) \
+ SPECIALIZE_CAST(devname, bfloat16, float) \
+ template <typename OUT_TYPE, typename IN_OUT> \
+ struct CastFunctor<devname, OUT_TYPE, IN_OUT> { \
+ void operator()(const devname& d, \
+ typename TTypes<OUT_TYPE>::Flat out_tensor, \
+ typename TTypes<IN_OUT>::ConstFlat in_tensor, \
+ bool truncate = false) { \
+ out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \
+ } \
+ };
+
namespace tensorflow {
+typedef std::function<void(OpKernelContext*, const Tensor&, Tensor*,
+ bool trunc)>
+ CastFunctorType;
+
// Common base class of Cast kernels
class CastOpBase : public OpKernel {
public:
@@ -38,8 +101,8 @@ class CastOpBase : public OpKernel {
DataType dst_dtype_;
DataType external_src_dtype_;
DataType external_dst_dtype_;
- std::function<void(OpKernelContext*, const Tensor&, Tensor*)> work_ = nullptr;
-
+ bool use_truncation_;
+ CastFunctorType work_ = nullptr;
Status Unimplemented();
TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase);
@@ -56,6 +119,23 @@ class CpuCastOp : public CastOpBase {
namespace functor {
+template <typename I>
+constexpr int MantissaWidth() {
+ return std::numeric_limits<I>::digits;
+}
+
+template <>
+constexpr int MantissaWidth<Eigen::half>() {
+ // Remember, there's 1 hidden bit
+ return 10 + 1;
+}
+
+template <>
+constexpr int MantissaWidth<bfloat16>() {
+ // Remember, there's 1 hidden bit
+ return 7 + 1;
+}
+
template <typename Device, typename Tout, typename Tin>
void Cast(const Device& d, typename TTypes<Tout>::Flat o,
typename TTypes<Tin>::ConstFlat i) {
@@ -65,7 +145,85 @@ void Cast(const Device& d, typename TTypes<Tout>::Flat o,
template <typename Device, typename Tout, typename Tin>
struct CastFunctor {
void operator()(const Device& d, typename TTypes<Tout>::Flat o,
- typename TTypes<Tin>::ConstFlat i);
+ typename TTypes<Tin>::ConstFlat i, bool truncate = false);
+};
+
+// Only enable LSBZeroSetterHelper for 64-bit and 32-bit input data types.
+// Specialize for others if needed in the future.
+template <typename I>
+typename std::enable_if<sizeof(I) == 8, void>::type EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) {
+ // Only zero the bits for non-NaNs.
+ // For NaNs, let the non-truncation version handle it.
+ if (!std::isnan(t)) {
+ uint64_t* p = reinterpret_cast<uint64_t*>(&t);
+ *p &= (0xFFFFFFFFFFFFFFFF << n);
+ }
+}
+
+template <typename I>
+typename std::enable_if<sizeof(I) == 4, void>::type EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) {
+ // Only zero the bits for non-NaNs.
+ // For NaNs, let the non-truncation version handle it.
+ if (!std::isnan(t)) {
+ uint32_t* p = reinterpret_cast<uint32_t*>(&t);
+ *p &= (0xFFFFFFFF << n);
+ }
+}
+
+// Set the MantissaWidth<I>() - MantissaWidth<O>() least significant bits to 0
+template <typename I, typename O>
+struct LSBZeroSetter {
+ EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter)
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const I operator()(const I& a) const {
+ constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
+ static_assert(
+ bits > 0,
+ "The output type must have fewer mantissa bits than the input type\n");
+ I t = a;
+ LSBZeroSetterHelper(t, bits);
+ return t;
+ }
+};
+
+template <typename I, typename O>
+struct LSBZeroSetter<std::complex<I>, std::complex<O>> {
+ EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter)
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()(
+ const std::complex<I>& a) const {
+ constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
+ static_assert(
+ bits > 0,
+ "The output type must have fewer mantissa bits than the input type\n");
+ I re = std::real(a);
+ I img = std::imag(a);
+ LSBZeroSetterHelper(re, bits);
+ LSBZeroSetterHelper(img, bits);
+ std::complex<I> toReturn(re, img);
+ return toReturn;
+ }
+};
+
+template <typename I, typename O>
+struct LSBZeroSetter<std::complex<I>, O> {
+ EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter)
+ // Zeroes the excess low mantissa bits of both the real and imaginary parts
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()(
+ const std::complex<I>& a) const {
+ constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
+ static_assert(
+ bits > 0,
+ "The output type must have fewer mantissa bits than the input type\n");
+ I re = std::real(a);
+ I img = std::imag(a);
+ LSBZeroSetterHelper(re, bits);
+ LSBZeroSetterHelper(img, bits);
+ std::complex<I> toReturn(re, img);
+ return toReturn;
+ }
};
} // end namespace functor
diff --git a/tensorflow/core/kernels/cast_op_gpu.cu.cc b/tensorflow/core/kernels/cast_op_gpu.cu.cc
index 607e7f5efd..036996fca2 100644
--- a/tensorflow/core/kernels/cast_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/cast_op_gpu.cu.cc
@@ -18,22 +18,19 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/bfloat16.h"
+#define SPECIALIZE_FOR_GPUS
#include "tensorflow/core/kernels/cast_op.h"
+#undef SPECIALIZE_FOR_GPUS
namespace tensorflow {
namespace functor {
typedef Eigen::GpuDevice GPUDevice;
-template <typename O, typename I>
-struct CastFunctor<GPUDevice, O, I> {
- void operator()(const GPUDevice& d, typename TTypes<O>::Flat o,
- typename TTypes<I>::ConstFlat i) {
- Cast<GPUDevice, O, I>(d, o, i);
- }
-};
+CAST_FUNCTORS(GPUDevice);
#define DEFINE(O, I) template struct CastFunctor<GPUDevice, O, I>
+
#define DEFINE_ALL_FROM(in_type) \
DEFINE(in_type, bool); \
DEFINE(in_type, uint8); \
@@ -59,14 +56,43 @@ DEFINE_ALL_FROM(int8);
DEFINE_ALL_FROM(int16);
DEFINE_ALL_FROM(int32);
DEFINE_ALL_FROM(int64);
-DEFINE_ALL_FROM(Eigen::half);
-DEFINE_ALL_FROM(float);
DEFINE_ALL_FROM(double);
-DEFINE_ALL_FROM(std::complex<float>);
DEFINE_ALL_FROM(std::complex<double>);
-DEFINE(bfloat16, float);
DEFINE(float, bfloat16);
+#define DEFINE_ALL_TO_FLOAT(out_type) \
+ DEFINE(out_type, bool); \
+ DEFINE(out_type, uint8); \
+ DEFINE(out_type, uint16); \
+ DEFINE(out_type, uint32); \
+ DEFINE(out_type, uint64); \
+ DEFINE(out_type, int8); \
+ DEFINE(out_type, int16); \
+ DEFINE(out_type, int32); \
+ DEFINE(out_type, int64); \
+ DEFINE(out_type, Eigen::half); \
+ DEFINE(out_type, float); \
+ DEFINE(out_type, std::complex<float>)
+
+#define DEFINE_ALL_TO_HALF(out_type) \
+ DEFINE(out_type, bool); \
+ DEFINE(out_type, uint8); \
+ DEFINE(out_type, uint16); \
+ DEFINE(out_type, uint32); \
+ DEFINE(out_type, uint64); \
+ DEFINE(out_type, int8); \
+ DEFINE(out_type, int16); \
+ DEFINE(out_type, int32); \
+ DEFINE(out_type, int64); \
+ DEFINE(out_type, Eigen::half)
+
+DEFINE_ALL_TO_HALF(Eigen::half);
+DEFINE_ALL_TO_HALF(bfloat16);
+DEFINE_ALL_TO_FLOAT(float);
+DEFINE_ALL_TO_FLOAT(std::complex<float>);
+
+#undef DEFINE_ALL_TO_FLOAT
+#undef DEFINE_ALL_TO_HALF
#undef DEFINE_ALL_FROM
#undef DEFINE
diff --git a/tensorflow/core/kernels/cast_op_impl.h b/tensorflow/core/kernels/cast_op_impl.h
index fe821b25df..b899bac681 100644
--- a/tensorflow/core/kernels/cast_op_impl.h
+++ b/tensorflow/core/kernels/cast_op_impl.h
@@ -25,22 +25,10 @@ namespace tensorflow {
namespace functor {
-template <typename O, typename I>
-struct CastFunctor<Eigen::ThreadPoolDevice, O, I> {
- void operator()(const Eigen::ThreadPoolDevice& d, typename TTypes<O>::Flat o,
- typename TTypes<I>::ConstFlat i) {
- o.device(d) = i.template cast<O>();
- }
-};
+CAST_FUNCTORS(Eigen::ThreadPoolDevice);
#ifdef TENSORFLOW_USE_SYCL
-template <typename O, typename I>
-struct CastFunctor<Eigen::SyclDevice, O, I> {
- void operator()(const Eigen::SyclDevice& d, typename TTypes<O>::Flat o,
- typename TTypes<I>::ConstFlat i) {
- o.device(d) = i.template cast<O>();
- }
-};
+CAST_FUNCTORS(Eigen::SyclDevice);
#endif // TENSORFLOW_USE_SYCL
} // namespace functor
@@ -68,139 +56,103 @@ struct CastFunctor<Eigen::SyclDevice, O, I> {
CURRY_TYPES3_NO_BF16(FN, arg0, arg1) \
FN(arg0, arg1, bfloat16);
-#define CAST_CASE(DEVICE, IN, OUT) \
- if (DataTypeToEnum<OUT>::value == dst_dtype) { \
- return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) { \
- functor::CastFunctor<DEVICE, OUT, IN> func; \
- func(ctx->eigen_device<DEVICE>(), out->flat<OUT>(), inp.flat<IN>()); \
- }; \
+#define CAST_CASE(DEVICE, IN, OUT) \
+ if (DataTypeToEnum<OUT>::value == dst_dtype) { \
+ return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out, \
+ bool truncate) { \
+ functor::CastFunctor<DEVICE, OUT, IN> func; \
+ func(ctx->eigen_device<DEVICE>(), out->flat<OUT>(), inp.flat<IN>(), \
+ truncate); \
+ }; \
}
// The functions below are implemented in the cast_op_impl_*.cc files.
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromBool(DataType dst_dtype);
+CastFunctorType GetCpuCastFromBool(DataType dst_dtype);
+
+CastFunctorType GetCpuCastFromUint8(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint8(DataType dst_dtype);
+CastFunctorType GetCpuCastFromUint16(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint16(DataType dst_dtype);
+CastFunctorType GetCpuCastFromInt8(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint32(DataType dst_dtype);
+CastFunctorType GetCpuCastFromUint32(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint64(DataType dst_dtype);
+CastFunctorType GetCpuCastFromUint64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt8(DataType dst_dtype);
+CastFunctorType GetCpuCastFromInt8(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt16(DataType dst_dtype);
+CastFunctorType GetCpuCastFromInt16(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt32(DataType dst_dtype);
+CastFunctorType GetCpuCastFromInt32(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt64(DataType dst_dtype);
+CastFunctorType GetCpuCastFromInt64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromHalf(DataType dst_dtype);
+CastFunctorType GetCpuCastFromHalf(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromFloat(DataType dst_dtype);
+CastFunctorType GetCpuCastFromFloat(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromDouble(DataType dst_dtype);
+CastFunctorType GetCpuCastFromDouble(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromComplex64(DataType dst_dtype);
+CastFunctorType GetCpuCastFromComplex64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromComplex128(DataType dst_dtype);
+CastFunctorType GetCpuCastFromComplex128(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromBfloat(DataType dst_dtype);
+CastFunctorType GetCpuCastFromBfloat(DataType dst_dtype);
#if GOOGLE_CUDA
// Same, for GPU.
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromBool(DataType dst_dtype);
+CastFunctorType GetGpuCastFromBool(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint8(DataType dst_dtype);
+CastFunctorType GetGpuCastFromUint8(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint16(DataType dst_dtype);
+CastFunctorType GetGpuCastFromUint16(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint32(DataType dst_dtype);
+CastFunctorType GetGpuCastFromInt8(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint64(DataType dst_dtype);
+CastFunctorType GetGpuCastFromUint32(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt8(DataType dst_dtype);
+CastFunctorType GetGpuCastFromUint64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt16(DataType dst_dtype);
+CastFunctorType GetGpuCastFromInt16(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt32(DataType dst_dtype);
+CastFunctorType GetGpuCastFromInt32(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt64(DataType dst_dtype);
+CastFunctorType GetGpuCastFromInt64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromHalf(DataType dst_dtype);
+CastFunctorType GetGpuCastFromHalf(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromFloat(DataType dst_dtype);
+CastFunctorType GetGpuCastFromFloat(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromDouble(DataType dst_dtype);
+CastFunctorType GetGpuCastFromDouble(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromComplex64(DataType dst_dtype);
+CastFunctorType GetGpuCastFromComplex64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromComplex128(DataType dst_dtype);
+CastFunctorType GetGpuCastFromComplex128(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromBfloat(DataType dst_dtype);
+CastFunctorType GetGpuCastFromBfloat(DataType dst_dtype);
#endif // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromBool(DataType dst_dtype);
+CastFunctorType GetSyclCastFromBool(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint8(DataType dst_dtype);
+CastFunctorType GetSyclCastFromUint8(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint16(DataType dst_dtype);
+CastFunctorType GetSyclCastFromUint16(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint32(DataType dst_dtype);
+CastFunctorType GetSyclCastFromUint32(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint64(DataType dst_dtype);
+CastFunctorType GetSyclCastFromUint64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromInt16(DataType dst_dtype);
+CastFunctorType GetSyclCastFromInt16(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromInt32(DataType dst_dtype);
+CastFunctorType GetSyclCastFromInt32(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromInt64(DataType dst_dtype);
+CastFunctorType GetSyclCastFromInt64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromFloat(DataType dst_dtype);
+CastFunctorType GetSyclCastFromFloat(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromDouble(DataType dst_dtype);
+CastFunctorType GetSyclCastFromDouble(DataType dst_dtype);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cast_op_impl_bfloat.cc b/tensorflow/core/kernels/cast_op_impl_bfloat.cc
index bfa7ba0d47..96aae15608 100644
--- a/tensorflow/core/kernels/cast_op_impl_bfloat.cc
+++ b/tensorflow/core/kernels/cast_op_impl_bfloat.cc
@@ -22,20 +22,19 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromBfloat(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromBfloat(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, bfloat16);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromBfloat(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromBfloat(DataType dst_dtype) {
if (dst_dtype == DT_FLOAT) {
- return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) {
+ return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out,
+ bool truncate) {
functor::CastFunctor<GPUDevice, float, bfloat16> func;
func(ctx->eigen_device<GPUDevice>(), out->flat<float>(),
- inp.flat<bfloat16>());
+ inp.flat<bfloat16>(), truncate);
};
}
return nullptr;
diff --git a/tensorflow/core/kernels/cast_op_impl_bool.cc b/tensorflow/core/kernels/cast_op_impl_bool.cc
index c5c7394b43..792d4781f2 100644
--- a/tensorflow/core/kernels/cast_op_impl_bool.cc
+++ b/tensorflow/core/kernels/cast_op_impl_bool.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromBool(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromBool(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, bool);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromBool(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromBool(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, bool);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromBool(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromBool(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromBool(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, bool);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_complex128.cc b/tensorflow/core/kernels/cast_op_impl_complex128.cc
index 52899d58cd..9a184e5954 100644
--- a/tensorflow/core/kernels/cast_op_impl_complex128.cc
+++ b/tensorflow/core/kernels/cast_op_impl_complex128.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromComplex128(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromComplex128(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, std::complex<double>);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromComplex128(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromComplex128(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, std::complex<double>);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_complex64.cc b/tensorflow/core/kernels/cast_op_impl_complex64.cc
index 617bda53d5..77bc620b46 100644
--- a/tensorflow/core/kernels/cast_op_impl_complex64.cc
+++ b/tensorflow/core/kernels/cast_op_impl_complex64.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromComplex64(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromComplex64(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, std::complex<float>);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromComplex64(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromComplex64(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, std::complex<float>);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_double.cc b/tensorflow/core/kernels/cast_op_impl_double.cc
index 7dc485ddad..ff9056897f 100644
--- a/tensorflow/core/kernels/cast_op_impl_double.cc
+++ b/tensorflow/core/kernels/cast_op_impl_double.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromDouble(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromDouble(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, double);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromDouble(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromDouble(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, double);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromDouble(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromDouble(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromDouble(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, double);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_float.cc b/tensorflow/core/kernels/cast_op_impl_float.cc
index 1c933914fd..f1e8f0e37b 100644
--- a/tensorflow/core/kernels/cast_op_impl_float.cc
+++ b/tensorflow/core/kernels/cast_op_impl_float.cc
@@ -22,15 +22,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromFloat(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromFloat(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, float);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromFloat(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromFloat(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, GPUDevice, float);
return nullptr;
}
@@ -38,8 +36,7 @@ GetGpuCastFromFloat(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromFloat(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromFloat(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, float);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_half.cc b/tensorflow/core/kernels/cast_op_impl_half.cc
index ef4b94e326..5da3a01352 100644
--- a/tensorflow/core/kernels/cast_op_impl_half.cc
+++ b/tensorflow/core/kernels/cast_op_impl_half.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromHalf(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromHalf(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, Eigen::half);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromHalf(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromHalf(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, Eigen::half);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_int16.cc b/tensorflow/core/kernels/cast_op_impl_int16.cc
index 59360f7445..440ee88fb5 100644
--- a/tensorflow/core/kernels/cast_op_impl_int16.cc
+++ b/tensorflow/core/kernels/cast_op_impl_int16.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt16(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromInt16(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, int16);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt16(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromInt16(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, int16);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromInt16(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromInt16(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromInt16(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int16);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_int32.cc b/tensorflow/core/kernels/cast_op_impl_int32.cc
index a867392fde..4b3e7efddc 100644
--- a/tensorflow/core/kernels/cast_op_impl_int32.cc
+++ b/tensorflow/core/kernels/cast_op_impl_int32.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt32(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromInt32(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, int32);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt32(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromInt32(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, int32);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromInt32(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromInt32(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromInt32(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int32);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_int64.cc b/tensorflow/core/kernels/cast_op_impl_int64.cc
index 467a8f6c89..0f711aa560 100644
--- a/tensorflow/core/kernels/cast_op_impl_int64.cc
+++ b/tensorflow/core/kernels/cast_op_impl_int64.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt64(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromInt64(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, int64);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt64(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromInt64(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, int64);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromInt64(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromInt64(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromInt64(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int64);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_int8.cc b/tensorflow/core/kernels/cast_op_impl_int8.cc
index 21002a4321..eac185d5a0 100644
--- a/tensorflow/core/kernels/cast_op_impl_int8.cc
+++ b/tensorflow/core/kernels/cast_op_impl_int8.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt8(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromInt8(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, int8);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt8(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromInt8(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, int8);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromInt8(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromInt8(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromInt8(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int8);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_uint16.cc b/tensorflow/core/kernels/cast_op_impl_uint16.cc
index cd829bae2a..3aebbdc1f3 100644
--- a/tensorflow/core/kernels/cast_op_impl_uint16.cc
+++ b/tensorflow/core/kernels/cast_op_impl_uint16.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint16(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromUint16(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, uint16);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint16(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromUint16(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint16);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromUint16(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint16(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromUint16(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint16);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_uint32.cc b/tensorflow/core/kernels/cast_op_impl_uint32.cc
index d1a854d98b..86f5961bcc 100644
--- a/tensorflow/core/kernels/cast_op_impl_uint32.cc
+++ b/tensorflow/core/kernels/cast_op_impl_uint32.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint32(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromUint32(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, uint32);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint32(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromUint32(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint32);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromUint32(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint32(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromUint32(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint32);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_uint64.cc b/tensorflow/core/kernels/cast_op_impl_uint64.cc
index 604e0424fc..6478c266ee 100644
--- a/tensorflow/core/kernels/cast_op_impl_uint64.cc
+++ b/tensorflow/core/kernels/cast_op_impl_uint64.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint64(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromUint64(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, uint64);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint64(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromUint64(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint64);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromUint64(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint64(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromUint64(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint64);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_uint8.cc b/tensorflow/core/kernels/cast_op_impl_uint8.cc
index 2d1a6f3a4e..b22547a23e 100644
--- a/tensorflow/core/kernels/cast_op_impl_uint8.cc
+++ b/tensorflow/core/kernels/cast_op_impl_uint8.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint8(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromUint8(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, uint8);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint8(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromUint8(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint8);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromUint8(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint8(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromUint8(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint8);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_test.cc b/tensorflow/core/kernels/cast_op_test.cc
index 9bbf7afb16..cb305de5e3 100644
--- a/tensorflow/core/kernels/cast_op_test.cc
+++ b/tensorflow/core/kernels/cast_op_test.cc
@@ -40,17 +40,27 @@ static Graph* Cast(int num) {
class CastOpTest : public OpsTestBase {
protected:
- void MakeOp(DataType src, DataType dst) {
- TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast")
- .Input(FakeInput(src))
- .Attr("SrcT", src)
- .Attr("DstT", dst)
- .Finalize(node_def()));
+ void MakeOp(DataType src, DataType dst, bool trunc = false) {
+ if (trunc) {
+ TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast")
+ .Input(FakeInput(src))
+ .Attr("SrcT", src)
+ .Attr("DstT", dst)
+ .Attr("Truncate", true)
+ .Finalize(node_def()));
+ } else {
+ TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast")
+ .Input(FakeInput(src))
+ .Attr("SrcT", src)
+ .Attr("DstT", dst)
+ .Finalize(node_def()));
+ }
+
TF_EXPECT_OK(InitOp());
}
template <typename INPUT, typename OUTPUT>
- void CheckCast() {
+ void CheckCast(bool trunc = false) {
DataType in_type = DataTypeToEnum<INPUT>::v();
DataType out_type = DataTypeToEnum<OUTPUT>::v();
MakeOp(in_type, out_type);
@@ -64,8 +74,9 @@ class CastOpTest : public OpsTestBase {
}
};
-#define TEST_CAST(in, out) \
- TEST_F(CastOpTest, TestCast##_##in##_##out) { CheckCast<in, out>(); }
+#define TEST_CAST(in, out) \
+ TEST_F(CastOpTest, TestCast##_##in##_##out) { CheckCast<in, out>(); } \
+ TEST_F(CastOpTest, TestCast2##_##in##_##out) { CheckCast<in, out>(true); }
#define TEST_ALL_CASTS_FROM(in) \
TEST_CAST(in, uint8); \
diff --git a/tensorflow/core/lib/bfloat16/bfloat16.h b/tensorflow/core/lib/bfloat16/bfloat16.h
index 1c130ba300..d6f3f26cd5 100644
--- a/tensorflow/core/lib/bfloat16/bfloat16.h
+++ b/tensorflow/core/lib/bfloat16/bfloat16.h
@@ -45,17 +45,25 @@ typedef std::complex<double> complex128;
struct bfloat16 {
B16_DEVICE_FUNC bfloat16() {}
- B16_DEVICE_FUNC explicit bfloat16(const float v) {
+ B16_DEVICE_FUNC static bfloat16 truncate_to_bfloat16(const float v) {
+ bfloat16 output;
if (float_isnan(v)) {
- value = NAN_VALUE;
- return;
+ output.value = NAN_VALUE;
+ return output;
}
const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
- value = p[0];
+ output.value = p[0];
#else
- value = p[1];
+ output.value = p[1];
#endif
+ return output;
+ }
+
+ B16_DEVICE_FUNC explicit bfloat16(const float v) {
+ // TODO(asabne) : change the below line to
+ // value = round_to_bfloat16(v).value;
+ value = truncate_to_bfloat16(v).value;
}
B16_DEVICE_FUNC explicit bfloat16(const double val)
@@ -169,8 +177,6 @@ struct bfloat16 {
// Converts a float point to bfloat16, with round-nearest-to-even as rounding
// method.
- // TODO(b/69266521): Add a truncate_to_bfloat16 function and make this
- // function as default behavior.
// TODO: There is a slightly faster implementation (8% faster on CPU)
// than this (documented in cl/175987786), that is exponentially harder to
// understand and document. Switch to the faster version when converting to
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 386ae9635a..77697756c4 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -114,6 +114,7 @@ REGISTER_OP("Cast")
.Output("y: DstT")
.Attr("SrcT: type")
.Attr("DstT: type")
+ .Attr("Truncate: bool = false")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("_HostCast")
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index cefd5b1206..15d2ccf9d2 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -154,6 +154,7 @@ TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
TFE_OpSetAttrType(op, "SrcT", src_type_enum);
TFE_OpSetAttrType(op, "DstT", dst_type_enum);
+ TFE_OpSetAttrBool(op, "Truncate", false);
TFE_TensorHandle* output = nullptr;
int num_outputs = 1;
TFE_Execute(op, &output, &num_outputs, out_status);