aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/lib
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-09-05 14:58:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-05 15:02:23 -0700
commitb84d27deb8c13eb426951dca6656de2f333f13d5 (patch)
treec65b05ddfa1d74b8acb3e384832b9d5ab12d57af /tensorflow/python/lib
parent9a343a2be2469442ea6bb87f23fc043e1d14cc3b (diff)
Support converting eager tensor to tf.float16 if a numpy half is passed.
This still defaults to float32 for all normal floats. PiperOrigin-RevId: 211704918
Diffstat (limited to 'tensorflow/python/lib')
-rw-r--r--tensorflow/python/lib/core/py_seq_tensor.cc25
1 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc
index 3b4f12ae31..269142a7c2 100644
--- a/tensorflow/python/lib/core/py_seq_tensor.cc
+++ b/tensorflow/python/lib/core/py_seq_tensor.cc
@@ -55,6 +55,10 @@ bool IsPyDouble(PyObject* obj) {
return PyIsInstance(obj, &PyDoubleArrType_Type); // NumPy double type.
}
+bool IsNumpyHalf(PyObject* obj) {
+ return PyIsInstance(obj, &PyHalfArrType_Type);
+}
+
bool IsPyFloat(PyObject* obj) {
return PyFloat_Check(obj) ||
PyIsInstance(obj, &PyFloatingArrType_Type); // NumPy float types
@@ -156,6 +160,8 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
}
} else if (IsPyDouble(obj)) {
*dtype = DT_DOUBLE;
+ } else if (IsNumpyHalf(obj)) {
+ *dtype = DT_HALF;
} else if (IsPyFloat(obj)) {
*dtype = DT_FLOAT;
} else if (PyBool_Check(obj) || PyIsInstance(obj, &PyBoolArrType_Type)) {
@@ -357,6 +363,17 @@ const char* ConvertOneFloat(PyObject* v, T* out) {
DEFINE_HELPER(ConvertDouble, double, DT_DOUBLE, ConvertOneFloat<double>);
DEFINE_HELPER(ConvertFloat, float, DT_FLOAT, ConvertOneFloat<float>);
+const char* ConvertOneNumpyHalf(PyObject* v, Eigen::half* out) {
+ // NOTE(nareshmodi): Is there a way to convert to C double without the
+ // intermediate Python double? This will help with ConvertOneFloat as well.
+ Safe_PyObjectPtr as_float = make_safe(PyNumber_Float(v));
+ double v_double = PyFloat_AS_DOUBLE(as_float.get());
+ *out = Eigen::half(v_double);
+
+ return nullptr;
+}
+DEFINE_HELPER(ConvertNumpyHalf, Eigen::half, DT_HALF, ConvertOneNumpyHalf);
+
// String support
const char* ConvertOneString(PyObject* v, string* out) {
@@ -452,6 +469,9 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) {
if (ConvertDouble(obj, shape, ret) == nullptr) return Status::OK();
break;
+ case DT_HALF:
+ RETURN_STRING_AS_STATUS(ConvertNumpyHalf(obj, shape, ret));
+
case DT_INT64:
if (ConvertInt64(obj, shape, ret) == nullptr) return Status::OK();
break;
@@ -489,8 +509,13 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) {
// final type.
RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret));
}
+
case DT_DOUBLE:
RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret));
+
+ case DT_HALF:
+ RETURN_STRING_AS_STATUS(ConvertNumpyHalf(obj, shape, ret));
+
case DT_INT64:
if (requested_dtype == DT_INVALID) {
const char* error = ConvertInt32(obj, shape, ret);