about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: TensorFlower Gardener <gardener@tensorflow.org> 2018-07-25 17:34:36 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-07-25 17:34:36 -0700
commit: c8eaae87533b460928f0141e14196545704fc47e (patch)
tree: 74191124397f612c62ecef03d0ef1e628a697fa1
parent: 70d9a489537a5a3c5fe85e75dd52bdc479966992 (diff)
parent: 535fa4919c9e247e2df673d8af874c3a39a38976 (diff)
Merge pull request #19533 from yongtang:19180-bfloat16
PiperOrigin-RevId: 206083795
-rw-r--r-- tensorflow/python/framework/fast_tensor_util.pyx | 7
-rw-r--r-- tensorflow/python/framework/tensor_util.py | 8
2 files changed, 14 insertions, 1 deletions
diff --git a/tensorflow/python/framework/fast_tensor_util.pyx b/tensorflow/python/framework/fast_tensor_util.pyx
index 17d112a1ec..2e3e15f53a 100644
--- a/tensorflow/python/framework/fast_tensor_util.pyx
+++ b/tensorflow/python/framework/fast_tensor_util.pyx
@@ -6,6 +6,13 @@ cimport numpy as np
from tensorflow.python.util import compat
+def AppendBFloat16ArrayToTensorProto(
+ tensor_proto, np.ndarray[np.uint16_t, ndim=1] nparray):
+ cdef long i, n
+ n = nparray.size
+ for i in range(n):
+ tensor_proto.half_val.append(nparray[i])
+
def AppendFloat16ArrayToTensorProto(
# For numpy, npy_half is a typedef for npy_uint16,
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index 8c9dfce7cc..9a0f34fad2 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -67,10 +67,16 @@ def SlowAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values):
[ExtractBitsFromBFloat16(x) for x in proto_values])
+def FastAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values):
+ fast_tensor_util.AppendBFloat16ArrayToTensorProto(
+ tensor_proto, np.asarray(
+ proto_values, dtype=dtypes.bfloat16.as_numpy_dtype).view(np.uint16))
+
+
if _FAST_TENSOR_UTIL_AVAILABLE:
_NP_TO_APPEND_FN = {
dtypes.bfloat16.as_numpy_dtype:
- SlowAppendBFloat16ArrayToTensorProto,
+ FastAppendBFloat16ArrayToTensorProto,
np.float16:
_MediumAppendFloat16ArrayToTensorProto,
np.float32: