diff options
author | 2018-05-24 16:53:33 -0700 | |
---|---|---|
committer | 2018-05-24 16:56:09 -0700 | |
commit | 9532bbbf994df5de1fa6550f3cf9f4dc08fcd640 (patch) | |
tree | 8fc557a0e618de844f088882fd3c29f622c06c13 /tensorflow/python/lib | |
parent | b9b93e90eb523712747b5cb2f70c738115d5f626 (diff) |
When converting a numpy float64 to an EagerTensor, always ensure that it
becomes a float64 tensor.
Earlier py_seq_tensor would fall back to a float32 if not explicitly requesting
a float64 (which would not happen if we had no other information).
PiperOrigin-RevId: 197977260
Diffstat (limited to 'tensorflow/python/lib')
-rw-r--r-- | tensorflow/python/lib/core/py_seq_tensor.cc | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index 32ea737a99..386be35ba2 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -51,6 +51,10 @@ bool IsPyInt(PyObject* obj) { #endif } +bool IsPyDouble(PyObject* obj) { + return PyIsInstance(obj, &PyDoubleArrType_Type); // NumPy double type. +} + bool IsPyFloat(PyObject* obj) { return PyFloat_Check(obj) || PyIsInstance(obj, &PyFloatingArrType_Type); // NumPy float types @@ -113,8 +117,10 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) { "Attempted to convert an invalid sequence to a Tensor."); } } - } else if (IsPyFloat(obj)) { + } else if (IsPyDouble(obj)) { *dtype = DT_DOUBLE; + } else if (IsPyFloat(obj)) { + *dtype = DT_FLOAT; } else if (PyBool_Check(obj) || PyIsInstance(obj, &PyBoolArrType_Type)) { // Have to test for bool before int, since IsInt(True/False) == true. *dtype = DT_BOOL; @@ -433,7 +439,7 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) { break; } switch (infer_dtype) { - case DT_DOUBLE: + case DT_FLOAT: // TODO(josh11b): Handle mixed floats and complex numbers? if (requested_dtype == DT_INVALID) { // TensorFlow uses float32s to represent floating point numbers @@ -446,7 +452,8 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) { // final type. RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret)); } - + case DT_DOUBLE: + RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret)); case DT_INT64: if (requested_dtype == DT_INVALID) { const char* error = ConvertInt32(obj, shape, ret); |