diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-25 17:34:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-25 17:34:36 -0700 |
commit | c8eaae87533b460928f0141e14196545704fc47e (patch) | |
tree | 74191124397f612c62ecef03d0ef1e628a697fa1 | |
parent | 70d9a489537a5a3c5fe85e75dd52bdc479966992 (diff) | |
parent | 535fa4919c9e247e2df673d8af874c3a39a38976 (diff) |
Merge pull request #19533 from yongtang:19180-bfloat16
PiperOrigin-RevId: 206083795
-rw-r--r-- | tensorflow/python/framework/fast_tensor_util.pyx | 7 | ||||
-rw-r--r-- | tensorflow/python/framework/tensor_util.py | 8 |
2 files changed, 14 insertions, 1 deletions
diff --git a/tensorflow/python/framework/fast_tensor_util.pyx b/tensorflow/python/framework/fast_tensor_util.pyx index 17d112a1ec..2e3e15f53a 100644 --- a/tensorflow/python/framework/fast_tensor_util.pyx +++ b/tensorflow/python/framework/fast_tensor_util.pyx @@ -6,6 +6,13 @@ cimport numpy as np from tensorflow.python.util import compat +def AppendBFloat16ArrayToTensorProto( + tensor_proto, np.ndarray[np.uint16_t, ndim=1] nparray): + cdef long i, n + n = nparray.size + for i in range(n): + tensor_proto.half_val.append(nparray[i]) + def AppendFloat16ArrayToTensorProto( # For numpy, npy_half is a typedef for npy_uint16, diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 8c9dfce7cc..9a0f34fad2 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -67,10 +67,16 @@ def SlowAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values): [ExtractBitsFromBFloat16(x) for x in proto_values]) +def FastAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values): + fast_tensor_util.AppendBFloat16ArrayToTensorProto( + tensor_proto, np.asarray( + proto_values, dtype=dtypes.bfloat16.as_numpy_dtype).view(np.uint16)) + + if _FAST_TENSOR_UTIL_AVAILABLE: _NP_TO_APPEND_FN = { dtypes.bfloat16.as_numpy_dtype: - SlowAppendBFloat16ArrayToTensorProto, + FastAppendBFloat16ArrayToTensorProto, np.float16: _MediumAppendFloat16ArrayToTensorProto, np.float32: |