diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-09-05 14:58:10 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-05 15:02:23 -0700 |
commit | b84d27deb8c13eb426951dca6656de2f333f13d5 (patch) | |
tree | c65b05ddfa1d74b8acb3e384832b9d5ab12d57af /tensorflow/python/lib | |
parent | 9a343a2be2469442ea6bb87f23fc043e1d14cc3b (diff) |
Support converting eager tensor to tf.float16 if a numpy half is passed.
This still defaults to float32 for all normal floats.
PiperOrigin-RevId: 211704918
Diffstat (limited to 'tensorflow/python/lib')
-rw-r--r-- | tensorflow/python/lib/core/py_seq_tensor.cc | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index 3b4f12ae31..269142a7c2 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -55,6 +55,10 @@ bool IsPyDouble(PyObject* obj) { return PyIsInstance(obj, &PyDoubleArrType_Type); // NumPy double type. } +bool IsNumpyHalf(PyObject* obj) { + return PyIsInstance(obj, &PyHalfArrType_Type); +} + bool IsPyFloat(PyObject* obj) { return PyFloat_Check(obj) || PyIsInstance(obj, &PyFloatingArrType_Type); // NumPy float types @@ -156,6 +160,8 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) { } } else if (IsPyDouble(obj)) { *dtype = DT_DOUBLE; + } else if (IsNumpyHalf(obj)) { + *dtype = DT_HALF; } else if (IsPyFloat(obj)) { *dtype = DT_FLOAT; } else if (PyBool_Check(obj) || PyIsInstance(obj, &PyBoolArrType_Type)) { @@ -357,6 +363,17 @@ const char* ConvertOneFloat(PyObject* v, T* out) { DEFINE_HELPER(ConvertDouble, double, DT_DOUBLE, ConvertOneFloat<double>); DEFINE_HELPER(ConvertFloat, float, DT_FLOAT, ConvertOneFloat<float>); +const char* ConvertOneNumpyHalf(PyObject* v, Eigen::half* out) { + // NOTE(nareshmodi): Is there a way to convert to C double without the + // intermediate Python double? This will help with ConvertOneFloat as well. + Safe_PyObjectPtr as_float = make_safe(PyNumber_Float(v)); + double v_double = PyFloat_AS_DOUBLE(as_float.get()); + *out = Eigen::half(v_double); + + return nullptr; +} +DEFINE_HELPER(ConvertNumpyHalf, Eigen::half, DT_HALF, ConvertOneNumpyHalf); + // String support const char* ConvertOneString(PyObject* v, string* out) { @@ -452,6 +469,9 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) { if (ConvertDouble(obj, shape, ret) == nullptr) return Status::OK(); break; + case DT_HALF: + RETURN_STRING_AS_STATUS(ConvertNumpyHalf(obj, shape, ret)); + case DT_INT64: if (ConvertInt64(obj, shape, ret) == nullptr) return Status::OK(); break; @@ -489,8 +509,13 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) { // final type. RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret)); } + case DT_DOUBLE: RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret)); + + case DT_HALF: + RETURN_STRING_AS_STATUS(ConvertNumpyHalf(obj, shape, ret)); + case DT_INT64: if (requested_dtype == DT_INVALID) { const char* error = ConvertInt32(obj, shape, ret); |