diff options
author | 2018-01-16 15:52:12 -0800 | |
---|---|---|
committer | 2018-01-16 15:55:47 -0800 | |
commit | ccbd14b741e6efbe51769f0f1b9cb3719c42c23b (patch) | |
tree | bcc79094e48982780b084cd19110d903186eea8f | |
parent | 287fe4f2404a7b69ffce89cc41ff3f049ca7a08b (diff) |
Enable bfloat16 for CPU kernels
PiperOrigin-RevId: 182124532
29 files changed, 373 insertions, 212 deletions
diff --git a/tensorflow/contrib/batching/util/BUILD b/tensorflow/contrib/batching/util/BUILD index 0df7d456da..2a84a7712a 100644 --- a/tensorflow/contrib/batching/util/BUILD +++ b/tensorflow/contrib/batching/util/BUILD @@ -26,6 +26,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/core/kernels/batching_util:periodic_function_dynamic", + "//third_party/eigen3", ], ) diff --git a/tensorflow/contrib/ffmpeg/default/BUILD b/tensorflow/contrib/ffmpeg/default/BUILD index 949ae9ad9e..6b455567d7 100644 --- a/tensorflow/contrib/ffmpeg/default/BUILD +++ b/tensorflow/contrib/ffmpeg/default/BUILD @@ -19,6 +19,7 @@ cc_library( ], deps = [ "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", ], ) diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/BUILD b/tensorflow/contrib/tensor_forest/kernels/v4/BUILD index b7876e1df6..794b76d858 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/BUILD +++ b/tensorflow/contrib/tensor_forest/kernels/v4/BUILD @@ -302,6 +302,7 @@ cc_library( "//tensorflow/contrib/tensor_forest/proto:fertile_stats_proto_cc", ], [ + "//third_party/eigen3", "//tensorflow/contrib/decision_trees/proto:generic_tree_model_cc_headers_only", "//tensorflow/contrib/tensor_forest/proto:fertile_stats_proto_cc_headers_only", ], @@ -322,6 +323,7 @@ cc_library( srcs = ["params.cc"], hdrs = ["params.h"], deps = [ + "//third_party/eigen3", "//tensorflow/core:framework_headers_lib", ] + if_static( [ diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 54565e826e..370d32ba18 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -277,6 +277,7 @@ cc_library( "platform/platform.h", "platform/protobuf.h", "platform/types.h", + "lib/bfloat16/bfloat16.h", ] + tf_additional_proto_hdrs() + glob(tf_env_time_hdrs()), copts = tf_copts(), deps = tf_lib_proto_parsing_deps(), @@ -289,6 +290,7 @@ cc_library( cc_library( name = "lib", hdrs = [ + "lib/bfloat16/bfloat16.h", 
"lib/core/arena.h", "lib/core/bitmap.h", "lib/core/bits.h", @@ -560,6 +562,7 @@ cc_library( "framework/numeric_types.h", "framework/tensor_types.h", "framework/type_traits.h", + "lib/bfloat16/bfloat16.h", "platform/default/dynamic_annotations.h", "platform/default/integral_types.h", "platform/default/logging.h", @@ -1589,6 +1592,7 @@ cc_library( "platform/jpeg.h", ]), hdrs = [ + "lib/bfloat16/bfloat16.h", "lib/core/stringpiece.h", "lib/jpeg/jpeg_handle.h", "lib/jpeg/jpeg_mem.h", @@ -1616,6 +1620,7 @@ cc_library( "platform/gif.h", ]), hdrs = [ + "lib/bfloat16/bfloat16.h", "lib/core/stringpiece.h", "lib/gif/gif_io.h", "lib/gtl/cleanup.h", @@ -1643,6 +1648,7 @@ cc_library( "platform/png.h", ]), hdrs = [ + "lib/bfloat16/bfloat16.h", "lib/core/casts.h", "lib/core/stringpiece.h", "lib/png/png_io.h", diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h index 42752a49bb..988a18da0e 100644 --- a/tensorflow/core/framework/numeric_types.h +++ b/tensorflow/core/framework/numeric_types.h @@ -41,198 +41,39 @@ typedef Eigen::QInt32 qint32; typedef Eigen::QInt16 qint16; typedef Eigen::QUInt16 quint16; -// see framework/bfloat16.h for description. -struct bfloat16 { - EIGEN_DEVICE_FUNC bfloat16() {} - - EIGEN_DEVICE_FUNC explicit bfloat16(const float v) { - if (Eigen::numext::isnan(v)) { - value = NAN_VALUE; - return; - } - const uint16_t* p = reinterpret_cast<const uint16_t*>(&v); -#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - value = p[0]; -#else - value = p[1]; -#endif - } - - // Following the convention of numpy, converting between complex and - // float will lead to loss of imag value. 
- explicit EIGEN_DEVICE_FUNC bfloat16(const complex64& val) - : bfloat16(val.real()) {} - - explicit EIGEN_DEVICE_FUNC bfloat16(const complex128& val) - : bfloat16(static_cast<float>(val.real())) {} - - template <class T> - explicit EIGEN_DEVICE_FUNC bfloat16(const T& val) - : bfloat16(static_cast<float>(val)) {} - - EIGEN_DEVICE_FUNC explicit operator float() const { - float result; - - uint16_t* q = reinterpret_cast<uint16_t*>(&result); - -#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - q[0] = value; - q[1] = 0; -#else - q[0] = 0; - q[1] = value; -#endif - return result; - } - - EIGEN_DEVICE_FUNC explicit operator bool() const { - return static_cast<bool>(float(*this)); - } - - EIGEN_DEVICE_FUNC explicit operator Eigen::half() const { - return static_cast<Eigen::half>(float(*this)); - } - - EIGEN_DEVICE_FUNC explicit operator short() const { - return static_cast<short>(float(*this)); - } - - EIGEN_DEVICE_FUNC explicit operator int() const { - return static_cast<int>(float(*this)); - } - - EIGEN_DEVICE_FUNC explicit operator long() const { - return static_cast<long>(float(*this)); - } - - EIGEN_DEVICE_FUNC explicit operator char() const { - return static_cast<char>(float(*this)); - } - - EIGEN_DEVICE_FUNC explicit operator signed char() const { - return static_cast<signed char>(float(*this)); - } - - EIGEN_DEVICE_FUNC explicit operator unsigned char() const { - return static_cast<unsigned char>(float(*this)); - } - - EIGEN_DEVICE_FUNC explicit operator unsigned int() const { - return static_cast<unsigned int>(float(*this)); - } - - EIGEN_DEVICE_FUNC explicit operator unsigned long() const { - return static_cast<unsigned long>(float(*this)); - } - - EIGEN_DEVICE_FUNC explicit operator unsigned long long() const { - return static_cast<unsigned long long>(float(*this)); - } - - EIGEN_DEVICE_FUNC explicit operator long long() const { - return static_cast<long long>(float(*this)); - } - - EIGEN_DEVICE_FUNC explicit operator double() const { - return 
static_cast<double>(float(*this)); - } +} // namespace tensorflow - EIGEN_DEVICE_FUNC explicit operator complex64() const { - return complex64(float(*this), float(0.0)); - } - - EIGEN_DEVICE_FUNC explicit operator complex128() const { - return complex128(double(*this), double(0.0)); - } +namespace Eigen { +// TODO(xpan): We probably need to overwrite more methods to have correct eigen +// behavior. E.g. lowest(), is_integer, etc. See NumTraits.h in eigen. +template <> +struct NumTraits<tensorflow::bfloat16> + : GenericNumTraits<tensorflow::bfloat16> {}; - static bfloat16 epsilon() { - bfloat16 x; - x.value = 0x3c00; // 0x1.0p-7 - return x; - } +using ::tensorflow::operator==; +using ::tensorflow::operator!=; - uint16_t value; +namespace numext { - // A value that represents "not a number". - static const uint16_t NAN_VALUE = 0x7FC0; -}; - -inline bfloat16 operator+(bfloat16 a, bfloat16 b) { - return bfloat16(static_cast<float>(a) + static_cast<float>(b)); -} -inline bfloat16 operator-(bfloat16 a, bfloat16 b) { - return bfloat16(static_cast<float>(a) - static_cast<float>(b)); -} -inline bfloat16 operator*(bfloat16 a, bfloat16 b) { - return bfloat16(static_cast<float>(a) * static_cast<float>(b)); -} -inline bfloat16 operator/(bfloat16 a, bfloat16 b) { - return bfloat16(static_cast<float>(a) / static_cast<float>(b)); -} -inline bfloat16 operator-(bfloat16 a) { - a.value ^= 0x8000; - return a; -} -inline bool operator<(bfloat16 a, bfloat16 b) { - return static_cast<float>(a) < static_cast<float>(b); -} -inline bool operator<=(bfloat16 a, bfloat16 b) { - return static_cast<float>(a) <= static_cast<float>(b); -} -inline bool operator==(bfloat16 a, bfloat16 b) { - return static_cast<float>(a) == static_cast<float>(b); -} -inline bool operator!=(bfloat16 a, bfloat16 b) { - return static_cast<float>(a) != static_cast<float>(b); -} -inline bool operator>(bfloat16 a, bfloat16 b) { - return static_cast<float>(a) > static_cast<float>(b); -} -inline bool operator>=(bfloat16 a, 
bfloat16 b) { - return static_cast<float>(a) >= static_cast<float>(b); -} -inline bfloat16& operator+=(bfloat16& a, bfloat16 b) { - a = a + b; - return a; -} -inline bfloat16& operator-=(bfloat16& a, bfloat16 b) { - a = a - b; - return a; -} -inline bfloat16 operator++(bfloat16& a) { - a += bfloat16(1); - return a; -} -inline bfloat16 operator--(bfloat16& a) { - a -= bfloat16(1); - return a; -} -inline bfloat16 operator++(bfloat16& a, int) { - bfloat16 original_value = a; - ++a; - return original_value; -} -inline bfloat16 operator--(bfloat16& a, int) { - bfloat16 original_value = a; - --a; - return original_value; -} -inline bfloat16& operator*=(bfloat16& a, bfloat16 b) { - a = a * b; - return a; +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 log( + const tensorflow::bfloat16& x) { + return static_cast<tensorflow::bfloat16>(::logf(static_cast<float>(x))); } -inline bfloat16& operator/=(bfloat16& a, bfloat16 b) { - a = a / b; - return a; + +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 exp( + const tensorflow::bfloat16& x) { + return static_cast<tensorflow::bfloat16>(::expf(static_cast<float>(x))); } -} // end namespace tensorflow -namespace Eigen { template <> -struct NumTraits<tensorflow::bfloat16> : GenericNumTraits<uint16_t> {}; +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 abs( + const tensorflow::bfloat16& x) { + return static_cast<tensorflow::bfloat16>(::fabsf(static_cast<float>(x))); +} -using ::tensorflow::operator==; -using ::tensorflow::operator!=; +} // namespace numext } // namespace Eigen #ifdef COMPILER_MSVC diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h index 4bb37e4f6e..41563c464d 100644 --- a/tensorflow/core/framework/register_types.h +++ b/tensorflow/core/framework/register_types.h @@ -155,11 +155,16 @@ limitations under the License. 
TF_CALL_uint8(m) TF_CALL_int8(m) #define TF_CALL_REAL_NUMBER_TYPES(m) \ + TF_CALL_INTEGRAL_TYPES(m) \ + TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m) + +#define TF_CALL_REAL_NUMBER_TYPES_NO_BFLOAT16(m) \ TF_CALL_INTEGRAL_TYPES(m) TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m) -#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \ - TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m) TF_CALL_int64(m) \ - TF_CALL_uint16(m) TF_CALL_int16(m) TF_CALL_uint8(m) TF_CALL_int8(m) +#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \ + TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m) \ + TF_CALL_int64(m) TF_CALL_uint16(m) TF_CALL_int16(m) TF_CALL_uint8(m) \ + TF_CALL_int8(m) // Call "m" for all number types, including complex64 and complex128. #define TF_CALL_NUMBER_TYPES(m) \ @@ -194,6 +199,13 @@ limitations under the License. #define TF_CALL_QUANTIZED_TYPES(m) \ TF_CALL_qint8(m) TF_CALL_quint8(m) TF_CALL_qint32(m) +// Types used for save and restore ops. +#define TF_CALL_SAVE_RESTORE_TYPES(m) \ + TF_CALL_INTEGRAL_TYPES(m) \ + TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m) TF_CALL_complex64(m) \ + TF_CALL_complex128(m) TF_CALL_bool(m) TF_CALL_string(m) \ + TF_CALL_QUANTIZED_TYPES(m) + #ifdef TENSORFLOW_SYCL_NO_DOUBLE #define TF_CALL_SYCL_double(m) #else // TENSORFLOW_SYCL_NO_DOUBLE diff --git a/tensorflow/core/kernels/concat_lib_cpu.cc b/tensorflow/core/kernels/concat_lib_cpu.cc index 743e3acfd5..43731114c0 100644 --- a/tensorflow/core/kernels/concat_lib_cpu.cc +++ b/tensorflow/core/kernels/concat_lib_cpu.cc @@ -72,7 +72,6 @@ REGISTER(qint8) REGISTER(quint16) REGISTER(qint16) REGISTER(qint32) -REGISTER(bfloat16) TF_CALL_variant(REGISTER) #if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) && \ diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc index 8e480aa995..ae1b5da32e 100644 --- a/tensorflow/core/kernels/concat_op.cc +++ b/tensorflow/core/kernels/concat_op.cc @@ 
-172,7 +172,6 @@ REGISTER_CONCAT(qint8); REGISTER_CONCAT(quint16); REGISTER_CONCAT(qint16); REGISTER_CONCAT(qint32); -REGISTER_CONCAT(bfloat16); #undef REGISTER_CONCAT diff --git a/tensorflow/core/kernels/constant_op_gpu.cu.cc b/tensorflow/core/kernels/constant_op_gpu.cu.cc index 49beb499af..3487606778 100644 --- a/tensorflow/core/kernels/constant_op_gpu.cu.cc +++ b/tensorflow/core/kernels/constant_op_gpu.cu.cc @@ -77,7 +77,6 @@ struct FillFunctor<GPUDevice, T> { #define DEFINE_FILL_GPU(T) template struct FillFunctor<GPUDevice, T>; TF_CALL_REAL_NUMBER_TYPES(DEFINE_FILL_GPU); -TF_CALL_bfloat16(DEFINE_FILL_GPU); TF_CALL_bool(DEFINE_FILL_GPU); #undef DEFINE_FILL_GPU @@ -91,7 +90,6 @@ struct SetZeroFunctor<GPUDevice, T> { #define DEFINE_SETZERO_GPU(T) template struct SetZeroFunctor<GPUDevice, T>; TF_CALL_NUMBER_TYPES(DEFINE_SETZERO_GPU); -TF_CALL_bfloat16(DEFINE_SETZERO_GPU); TF_CALL_bool(DEFINE_SETZERO_GPU); #undef DEFINE_SETZERO_GPU @@ -105,7 +103,6 @@ struct SetOneFunctor<GPUDevice, T> { #define DEFINE_SETONE_GPU(T) template struct SetOneFunctor<GPUDevice, T>; TF_CALL_NUMBER_TYPES(DEFINE_SETONE_GPU); -TF_CALL_bfloat16(DEFINE_SETONE_GPU); TF_CALL_bool(DEFINE_SETONE_GPU); #undef DEFINE_SETONE_GPU diff --git a/tensorflow/core/kernels/cross_op.cc b/tensorflow/core/kernels/cross_op.cc index 05a33a97b4..b29524f1f9 100644 --- a/tensorflow/core/kernels/cross_op.cc +++ b/tensorflow/core/kernels/cross_op.cc @@ -105,6 +105,7 @@ TF_CALL_REAL_NUMBER_TYPES(DECLARE_GPU_KERNEL); REGISTER_KERNEL_BUILDER( \ Name("Cross").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ CrossOp<GPUDevice, type>); + TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_KERNEL); #undef REGISTER_GPU_KERNEL #endif diff --git a/tensorflow/core/kernels/fill_functor.cc b/tensorflow/core/kernels/fill_functor.cc index 35d9693f54..bde39770de 100644 --- a/tensorflow/core/kernels/fill_functor.cc +++ b/tensorflow/core/kernels/fill_functor.cc @@ -42,6 +42,7 @@ void SetZeroFunctor<Eigen::ThreadPoolDevice, string>::operator()( 
template struct SetZeroFunctor<Eigen::ThreadPoolDevice, T>; DEFINE_SETZERO_CPU(bool); DEFINE_SETZERO_CPU(Eigen::half); +DEFINE_SETZERO_CPU(bfloat16); DEFINE_SETZERO_CPU(float); DEFINE_SETZERO_CPU(double); DEFINE_SETZERO_CPU(uint8); @@ -87,6 +88,7 @@ void SetOneFunctor<Eigen::ThreadPoolDevice, T>::operator()( template struct SetOneFunctor<Eigen::ThreadPoolDevice, T>; DEFINE_SETONE_CPU(bool); DEFINE_SETONE_CPU(Eigen::half); +DEFINE_SETONE_CPU(bfloat16); DEFINE_SETONE_CPU(float); DEFINE_SETONE_CPU(double); DEFINE_SETONE_CPU(uint8); diff --git a/tensorflow/core/kernels/fill_functor.cu.cc b/tensorflow/core/kernels/fill_functor.cu.cc index 49beb499af..3487606778 100644 --- a/tensorflow/core/kernels/fill_functor.cu.cc +++ b/tensorflow/core/kernels/fill_functor.cu.cc @@ -77,7 +77,6 @@ struct FillFunctor<GPUDevice, T> { #define DEFINE_FILL_GPU(T) template struct FillFunctor<GPUDevice, T>; TF_CALL_REAL_NUMBER_TYPES(DEFINE_FILL_GPU); -TF_CALL_bfloat16(DEFINE_FILL_GPU); TF_CALL_bool(DEFINE_FILL_GPU); #undef DEFINE_FILL_GPU @@ -91,7 +90,6 @@ struct SetZeroFunctor<GPUDevice, T> { #define DEFINE_SETZERO_GPU(T) template struct SetZeroFunctor<GPUDevice, T>; TF_CALL_NUMBER_TYPES(DEFINE_SETZERO_GPU); -TF_CALL_bfloat16(DEFINE_SETZERO_GPU); TF_CALL_bool(DEFINE_SETZERO_GPU); #undef DEFINE_SETZERO_GPU @@ -105,7 +103,6 @@ struct SetOneFunctor<GPUDevice, T> { #define DEFINE_SETONE_GPU(T) template struct SetOneFunctor<GPUDevice, T>; TF_CALL_NUMBER_TYPES(DEFINE_SETONE_GPU); -TF_CALL_bfloat16(DEFINE_SETONE_GPU); TF_CALL_bool(DEFINE_SETONE_GPU); #undef DEFINE_SETONE_GPU diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc index 1700bcfca5..df60eda759 100644 --- a/tensorflow/core/kernels/save_restore_tensor.cc +++ b/tensorflow/core/kernels/save_restore_tensor.cc @@ -119,8 +119,7 @@ void SaveTensors( break; switch (input.dtype()) { - TF_CALL_POD_STRING_TYPES(WRITER_ADD) - TF_CALL_QUANTIZED_TYPES(WRITER_ADD) + 
TF_CALL_SAVE_RESTORE_TYPES(WRITER_ADD) default: context->SetStatus(errors::Unimplemented("Saving data type ", DataTypeString(input.dtype()), @@ -219,8 +218,7 @@ void RestoreTensor(OpKernelContext* context, break; switch (type) { - TF_CALL_POD_STRING_TYPES(READER_COPY) - TF_CALL_QUANTIZED_TYPES(READER_COPY) + TF_CALL_SAVE_RESTORE_TYPES(READER_COPY) default: context->SetStatus(errors::Unimplemented( "Restoring data type ", DataTypeString(type), " not yet supported")); diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc index a9e31cc336..82595de779 100644 --- a/tensorflow/core/kernels/slice_op.cc +++ b/tensorflow/core/kernels/slice_op.cc @@ -439,7 +439,6 @@ namespace functor { DECLARE_CPU_SPEC(T, 7); TF_CALL_ALL_TYPES(DECLARE_FOR_N); -TF_CALL_bfloat16(DECLARE_FOR_N); #undef DECLARE_FOR_N #undef DECLARE_CPU_SPEC @@ -456,7 +455,6 @@ TF_CALL_bfloat16(DECLARE_FOR_N); TF_CALL_POD_STRING_TYPES(REGISTER_SLICE); TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE); -TF_CALL_bfloat16(REGISTER_SLICE); #undef REGISTER_SLICE #else #define REGISTER_SLICE(type) \ @@ -469,7 +467,6 @@ TF_CALL_bfloat16(REGISTER_SLICE); TF_CALL_POD_STRING_TYPES(REGISTER_SLICE); TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE); -TF_CALL_bfloat16(REGISTER_SLICE); #undef REGISTER_SLICE #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/slice_op_cpu_impl.h b/tensorflow/core/kernels/slice_op_cpu_impl.h index a70805658e..58dc7df3e0 100644 --- a/tensorflow/core/kernels/slice_op_cpu_impl.h +++ b/tensorflow/core/kernels/slice_op_cpu_impl.h @@ -30,7 +30,6 @@ using CpuDevice = Eigen::ThreadPoolDevice; template struct functor::Slice<CpuDevice, T, CPU_PROVIDED_IXDIM>; TF_CALL_ALL_TYPES(DEFINE_CPU_KERNELS); -DEFINE_CPU_KERNELS(bfloat16); #undef DEFINE_CPU_KERNELS diff --git a/tensorflow/core/kernels/split_lib_cpu.cc b/tensorflow/core/kernels/split_lib_cpu.cc index 6583f96a91..25026208d1 100644 --- a/tensorflow/core/kernels/split_lib_cpu.cc +++ b/tensorflow/core/kernels/split_lib_cpu.cc @@ -41,7 
+41,6 @@ void Split<Eigen::ThreadPoolDevice, T>::operator()( TF_CALL_ALL_TYPES(DEFINE_CPU_KERNELS) DEFINE_CPU_KERNELS(quint8) -DEFINE_CPU_KERNELS(bfloat16) #ifdef TENSORFLOW_USE_SYCL template <typename T> diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc index 90d7e225ed..78badde27e 100644 --- a/tensorflow/core/kernels/split_op.cc +++ b/tensorflow/core/kernels/split_op.cc @@ -360,8 +360,6 @@ class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> { TF_CALL_ALL_TYPES(REGISTER_SPLIT); REGISTER_SPLIT(quint8); -// TODO(xpan): Merge bfloat16 into TF_CALL_ALL_TYPES -REGISTER_SPLIT(bfloat16); #undef REGISTER_SPLIT diff --git a/tensorflow/core/kernels/split_v_op.cc b/tensorflow/core/kernels/split_v_op.cc index 3316e5fcc9..f1078ac349 100644 --- a/tensorflow/core/kernels/split_v_op.cc +++ b/tensorflow/core/kernels/split_v_op.cc @@ -406,7 +406,6 @@ class SplitVOpGPU : public SplitVOpBase<GPUDevice, T, Tlen> { REGISTER_SPLIT(type, int64); TF_CALL_ALL_TYPES(REGISTER_SPLIT_LEN); -REGISTER_SPLIT_LEN(bfloat16); #undef REGISTER_SPLIT_LEN #undef REGISTER_SPLIT diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc index 73b6d4cf6a..7c213e14d2 100644 --- a/tensorflow/core/kernels/strided_slice_op.cc +++ b/tensorflow/core/kernels/strided_slice_op.cc @@ -386,7 +386,6 @@ class StridedSliceAssignOp : public OpKernel { StridedSliceAssignOp<CPUDevice, type>) TF_CALL_ALL_TYPES(REGISTER_STRIDED_SLICE); -REGISTER_STRIDED_SLICE(bfloat16); #undef REGISTER_STRIDED_SLICE diff --git a/tensorflow/core/kernels/strided_slice_op_impl.h b/tensorflow/core/kernels/strided_slice_op_impl.h index afe3a051e6..a84ba38ef4 100644 --- a/tensorflow/core/kernels/strided_slice_op_impl.h +++ b/tensorflow/core/kernels/strided_slice_op_impl.h @@ -288,7 +288,6 @@ DECLARE_FOR_N_GPU(int64); #endif // END GOOGLE_CUDA TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU); -DECLARE_FOR_N_CPU(bfloat16); #ifdef TENSORFLOW_USE_SYCL #define PREVENT_FOR_N_SYCL(T) 
\ diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc index 66aee2dfe2..af93d814ec 100644 --- a/tensorflow/core/kernels/tensor_array_ops.cc +++ b/tensorflow/core/kernels/tensor_array_ops.cc @@ -708,7 +708,6 @@ TF_CALL_POD_STRING_TYPES(REGISTER_GATHER_AND_PACK); REGISTER_GATHER_AND_PACK(quint8); REGISTER_GATHER_AND_PACK(qint8); REGISTER_GATHER_AND_PACK(qint32); -REGISTER_GATHER_AND_PACK(bfloat16); #undef REGISTER_GATHER_AND_PACK @@ -939,7 +938,6 @@ TF_CALL_POD_STRING_TYPES(REGISTER_CONCAT); REGISTER_CONCAT(quint8); REGISTER_CONCAT(qint8); REGISTER_CONCAT(qint32); -REGISTER_CONCAT(bfloat16); #undef REGISTER_CONCAT diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc index 96c051c636..2e0d18b634 100644 --- a/tensorflow/core/kernels/transpose_op.cc +++ b/tensorflow/core/kernels/transpose_op.cc @@ -230,7 +230,6 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx, .HostMemory("perm"), \ MklConjugateTransposeCpuOp); TF_CALL_ALL_TYPES(REGISTER); -REGISTER(bfloat16); #undef REGISTER #else // INTEL_MKL @@ -247,7 +246,6 @@ REGISTER(bfloat16); .HostMemory("perm"), \ ConjugateTransposeCpuOp); TF_CALL_ALL_TYPES(REGISTER) -REGISTER(bfloat16); #undef REGISTER #endif // INTEL_MKL diff --git a/tensorflow/core/lib/bfloat16/bfloat16.cc b/tensorflow/core/lib/bfloat16/bfloat16.cc new file mode 100644 index 0000000000..a591717fd1 --- /dev/null +++ b/tensorflow/core/lib/bfloat16/bfloat16.cc @@ -0,0 +1,25 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/lib/bfloat16/bfloat16.h" + +#include "third_party/eigen3/Eigen/Core" + +namespace tensorflow { + +B16_DEVICE_FUNC bfloat16::operator Eigen::half() const { + return static_cast<Eigen::half>(float(*this)); +} +} // end namespace tensorflow diff --git a/tensorflow/core/lib/bfloat16/bfloat16.h b/tensorflow/core/lib/bfloat16/bfloat16.h new file mode 100644 index 0000000000..f9cca0ef2a --- /dev/null +++ b/tensorflow/core/lib/bfloat16/bfloat16.h @@ -0,0 +1,276 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_ +#define TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_ + +#include <complex> + +#ifdef __CUDACC__ +// All functions callable from CUDA code must be qualified with __device__ +#define B16_DEVICE_FUNC __host__ __device__ + +#else +#define B16_DEVICE_FUNC + +#endif + +namespace Eigen { +struct half; +} + +namespace tensorflow { + +// Single precision complex. +typedef std::complex<float> complex64; +// Double precision complex. +typedef std::complex<double> complex128; + +// see framework/bfloat16.h for description. +struct bfloat16 { + B16_DEVICE_FUNC bfloat16() {} + + B16_DEVICE_FUNC explicit bfloat16(const float v) { + if (float_isnan(v)) { + value = NAN_VALUE; + return; + } + const uint16_t* p = reinterpret_cast<const uint16_t*>(&v); +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + value = p[0]; +#else + value = p[1]; +#endif + } + + B16_DEVICE_FUNC explicit bfloat16(const double val) + : bfloat16(static_cast<float>(val)) {} + // Following the convention of numpy, converting between complex and + // float will lead to loss of imag value. 
+ B16_DEVICE_FUNC explicit bfloat16(const complex64& val) + : bfloat16(val.real()) {} + + B16_DEVICE_FUNC explicit bfloat16(const complex128& val) + : bfloat16(static_cast<float>(val.real())) {} + + B16_DEVICE_FUNC explicit bfloat16(const unsigned short val) + : bfloat16(static_cast<float>(val)) {} + + B16_DEVICE_FUNC explicit bfloat16(const unsigned int val) + : bfloat16(static_cast<float>(val)) {} + + B16_DEVICE_FUNC explicit bfloat16(const int val) + : bfloat16(static_cast<float>(val)) {} + + B16_DEVICE_FUNC explicit bfloat16(const long val) + : bfloat16(static_cast<float>(val)) {} + + B16_DEVICE_FUNC explicit bfloat16(const long long val) + : bfloat16(static_cast<float>(val)) {} + + template <class T> + B16_DEVICE_FUNC explicit bfloat16(const T& val) + : bfloat16(static_cast<float>(val)) {} + + B16_DEVICE_FUNC explicit operator float() const { + float result; + + uint16_t* q = reinterpret_cast<uint16_t*>(&result); + +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + q[0] = value; + q[1] = 0; +#else + q[0] = 0; + q[1] = value; +#endif + return result; + } + + B16_DEVICE_FUNC explicit operator bool() const { + return static_cast<bool>(float(*this)); + } + + B16_DEVICE_FUNC explicit operator Eigen::half() const; + + B16_DEVICE_FUNC explicit operator short() const { + return static_cast<short>(float(*this)); + } + + B16_DEVICE_FUNC explicit operator int() const { + return static_cast<int>(float(*this)); + } + + B16_DEVICE_FUNC explicit operator long() const { + return static_cast<long>(float(*this)); + } + + B16_DEVICE_FUNC explicit operator char() const { + return static_cast<char>(float(*this)); + } + + B16_DEVICE_FUNC explicit operator signed char() const { + return static_cast<signed char>(float(*this)); + } + + B16_DEVICE_FUNC explicit operator unsigned char() const { + return static_cast<unsigned char>(float(*this)); + } + + B16_DEVICE_FUNC explicit operator unsigned short() const { + return static_cast<unsigned short>(float(*this)); + } + + B16_DEVICE_FUNC 
explicit operator unsigned int() const { + return static_cast<unsigned int>(float(*this)); + } + + B16_DEVICE_FUNC explicit operator unsigned long() const { + return static_cast<unsigned long>(float(*this)); + } + + B16_DEVICE_FUNC explicit operator unsigned long long() const { + return static_cast<unsigned long long>(float(*this)); + } + + B16_DEVICE_FUNC explicit operator long long() const { + return static_cast<long long>(float(*this)); + } + + B16_DEVICE_FUNC explicit operator double() const { + return static_cast<double>(float(*this)); + } + + B16_DEVICE_FUNC explicit operator complex64() const { + return complex64(float(*this), float(0.0)); + } + + B16_DEVICE_FUNC explicit operator complex128() const { + return complex128(double(*this), double(0.0)); + } + + static bfloat16 epsilon() { + bfloat16 x; + x.value = 0x3c00; // 0x1.0p-7 + return x; + } + + uint16_t value; + + // A value that represents "not a number". + static const uint16_t NAN_VALUE = 0x7FC0; + + private: + B16_DEVICE_FUNC bool float_isnan(const float& x) { +#ifdef __CUDA_ARCH__ + return ::isnan(x); +#else + return std::isnan(x); +#endif + } +}; + +B16_DEVICE_FUNC inline std::ostream& operator<<(std::ostream& os, + const bfloat16& dt) { + os << static_cast<float>(dt); + return os; +} + +B16_DEVICE_FUNC inline bfloat16 operator+(bfloat16 a, bfloat16 b) { + return bfloat16(static_cast<float>(a) + static_cast<float>(b)); +} +B16_DEVICE_FUNC inline bfloat16 operator+(bfloat16 a, int b) { + return bfloat16(static_cast<float>(a) + static_cast<float>(b)); +} +B16_DEVICE_FUNC inline bfloat16 operator+(int a, bfloat16 b) { + return bfloat16(static_cast<float>(a) + static_cast<float>(b)); +} +B16_DEVICE_FUNC inline bfloat16 operator-(bfloat16 a, bfloat16 b) { + return bfloat16(static_cast<float>(a) - static_cast<float>(b)); +} +B16_DEVICE_FUNC inline bfloat16 operator*(bfloat16 a, bfloat16 b) { + return bfloat16(static_cast<float>(a) * static_cast<float>(b)); +} +B16_DEVICE_FUNC inline bfloat16 
operator/(bfloat16 a, bfloat16 b) { + return bfloat16(static_cast<float>(a) / static_cast<float>(b)); +} +B16_DEVICE_FUNC inline bfloat16 operator-(bfloat16 a) { + a.value ^= 0x8000; + return a; +} +B16_DEVICE_FUNC inline bool operator<(bfloat16 a, bfloat16 b) { + return static_cast<float>(a) < static_cast<float>(b); +} +B16_DEVICE_FUNC inline bool operator<=(bfloat16 a, bfloat16 b) { + return static_cast<float>(a) <= static_cast<float>(b); +} +B16_DEVICE_FUNC inline bool operator==(bfloat16 a, bfloat16 b) { + return static_cast<float>(a) == static_cast<float>(b); +} +B16_DEVICE_FUNC inline bool operator!=(bfloat16 a, bfloat16 b) { + return static_cast<float>(a) != static_cast<float>(b); +} +B16_DEVICE_FUNC inline bool operator>(bfloat16 a, bfloat16 b) { + return static_cast<float>(a) > static_cast<float>(b); +} +B16_DEVICE_FUNC inline bool operator>=(bfloat16 a, bfloat16 b) { + return static_cast<float>(a) >= static_cast<float>(b); +} +B16_DEVICE_FUNC inline bfloat16& operator+=(bfloat16& a, bfloat16 b) { + a = a + b; + return a; +} +B16_DEVICE_FUNC inline bfloat16& operator-=(bfloat16& a, bfloat16 b) { + a = a - b; + return a; +} +B16_DEVICE_FUNC inline bfloat16 operator++(bfloat16& a) { + a += bfloat16(1); + return a; +} +B16_DEVICE_FUNC inline bfloat16 operator--(bfloat16& a) { + a -= bfloat16(1); + return a; +} +B16_DEVICE_FUNC inline bfloat16 operator++(bfloat16& a, int) { + bfloat16 original_value = a; + ++a; + return original_value; +} +B16_DEVICE_FUNC inline bfloat16 operator--(bfloat16& a, int) { + bfloat16 original_value = a; + --a; + return original_value; +} +B16_DEVICE_FUNC inline bfloat16& operator*=(bfloat16& a, bfloat16 b) { + a = a * b; + return a; +} +B16_DEVICE_FUNC inline bfloat16& operator/=(bfloat16& a, bfloat16 b) { + a = a / b; + return a; +} +} // end namespace tensorflow + +namespace std { +template <> +struct hash<tensorflow::bfloat16> { + size_t operator()(const tensorflow::bfloat16& v) const { + return 
hash<float>()(static_cast<float>(v)); + } +}; +} // namespace std + +#endif // TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_ diff --git a/tensorflow/core/lib/hash/hash.h b/tensorflow/core/lib/hash/hash.h index 0fb12966af..4d312ab7e8 100644 --- a/tensorflow/core/lib/hash/hash.h +++ b/tensorflow/core/lib/hash/hash.h @@ -65,6 +65,13 @@ struct hash<T*> { }; template <> +struct hash<bfloat16> { + size_t operator()(const bfloat16& t) const { + return std::hash<float>()(static_cast<float>(t)); + } +}; + +template <> struct hash<string> { size_t operator()(const string& s) const { return static_cast<size_t>(Hash64(s)); diff --git a/tensorflow/core/lib/strings/strcat.h b/tensorflow/core/lib/strings/strcat.h index 8e35549ed4..5835b0101d 100644 --- a/tensorflow/core/lib/strings/strcat.h +++ b/tensorflow/core/lib/strings/strcat.h @@ -119,6 +119,9 @@ class AlphaNum { AlphaNum(float f) // NOLINT(runtime/explicit) : piece_(digits_, strlen(FloatToBuffer(f, digits_))) {} + AlphaNum(bfloat16 f) // NOLINT(runtime/explicit) + : piece_(digits_, strlen(FloatToBuffer(static_cast<float>(f), digits_))) { + } AlphaNum(double f) // NOLINT(runtime/explicit) : piece_(digits_, strlen(DoubleToBuffer(f, digits_))) {} diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 6d83f8b7fd..e9c510c93c 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -510,6 +510,7 @@ def tf_additional_cloud_kernel_deps(): def tf_lib_proto_parsing_deps(): return [ ":protos_all_cc", + "//third_party/eigen3", "//tensorflow/core/platform/default/build_config:proto_parsing", ] diff --git a/tensorflow/core/platform/types.h b/tensorflow/core/platform/types.h index 6308e58847..e2dd5b003f 100644 --- a/tensorflow/core/platform/types.h +++ b/tensorflow/core/platform/types.h @@ -31,6 +31,12 @@ limitations under the License. 
#error Define the appropriate PLATFORM_<foo> macro for this platform #endif +#if defined(PLATFORM_WINDOWS) +#include "tensorflow/core/platform/windows/cpu_info.h" +#endif + +#include "tensorflow/core/lib/bfloat16/bfloat16.h" + namespace tensorflow { // Define tensorflow::string to refer to appropriate platform specific type. diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index d2fcef5304..6cb727ae88 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -395,6 +395,7 @@ tf_cc_shared_object( }), deps = [ "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", ], ) |