aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/lib
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-05-24 16:53:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-24 16:56:09 -0700
commit9532bbbf994df5de1fa6550f3cf9f4dc08fcd640 (patch)
tree8fc557a0e618de844f088882fd3c29f622c06c13 /tensorflow/python/lib
parentb9b93e90eb523712747b5cb2f70c738115d5f626 (diff)
When converting a numpy float64 to an EagerTensor, always ensure that it
becomes a float64 tensor. Earlier py_seq_tensor would fall back to a float32 if not explicitly requesting a float64 (which would not happen if we had no other information). PiperOrigin-RevId: 197977260
Diffstat (limited to 'tensorflow/python/lib')
-rw-r--r--tensorflow/python/lib/core/py_seq_tensor.cc13
1 files changed, 10 insertions, 3 deletions
diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc
index 32ea737a99..386be35ba2 100644
--- a/tensorflow/python/lib/core/py_seq_tensor.cc
+++ b/tensorflow/python/lib/core/py_seq_tensor.cc
@@ -51,6 +51,10 @@ bool IsPyInt(PyObject* obj) {
#endif
}
+bool IsPyDouble(PyObject* obj) {
+ return PyIsInstance(obj, &PyDoubleArrType_Type); // NumPy double type.
+}
+
bool IsPyFloat(PyObject* obj) {
return PyFloat_Check(obj) ||
PyIsInstance(obj, &PyFloatingArrType_Type); // NumPy float types
@@ -113,8 +117,10 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
"Attempted to convert an invalid sequence to a Tensor.");
}
}
- } else if (IsPyFloat(obj)) {
+ } else if (IsPyDouble(obj)) {
*dtype = DT_DOUBLE;
+ } else if (IsPyFloat(obj)) {
+ *dtype = DT_FLOAT;
} else if (PyBool_Check(obj) || PyIsInstance(obj, &PyBoolArrType_Type)) {
// Have to test for bool before int, since IsInt(True/False) == true.
*dtype = DT_BOOL;
@@ -433,7 +439,7 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) {
break;
}
switch (infer_dtype) {
- case DT_DOUBLE:
+ case DT_FLOAT:
// TODO(josh11b): Handle mixed floats and complex numbers?
if (requested_dtype == DT_INVALID) {
// TensorFlow uses float32s to represent floating point numbers
@@ -446,7 +452,8 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) {
// final type.
RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret));
}
-
+ case DT_DOUBLE:
+ RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret));
case DT_INT64:
if (requested_dtype == DT_INVALID) {
const char* error = ConvertInt32(obj, shape, ret);