aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2017-10-23 15:59:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-23 16:03:26 -0700
commit57023e1b6c7d27e76d6f41d80b95402b9d93c467 (patch)
tree8685d326134f64c0c08104182c3ccd9ec21ecd82
parentfd182dd02e431e7a7f16bd0ad1547405e591bc82 (diff)
tf.constant takes numpy dtypes in eager mode as well
PiperOrigin-RevId: 173184568
-rw-r--r--tensorflow/python/eager/tensor_test.py5
-rw-r--r--tensorflow/python/framework/constant_op.py5
2 files changed, 9 insertions, 1 deletions
diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py
index e31c03c08d..b52bbe44d4 100644
--- a/tensorflow/python/eager/tensor_test.py
+++ b/tensorflow/python/eager/tensor_test.py
@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
@@ -102,6 +103,10 @@ class TFETensorTest(test_util.TensorFlowTestCase):
t = _create_tensor(n)
self.assertAllEqual([[1, 2], [3, 4]], t)
+ def testConstantDtype(self):
+ self.assertEqual(constant_op.constant(1.0, dtype=np.int64).dtype,
+ dtypes.int64)
+
def testTensorAndNumpyMatrix(self):
expected = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
actual = _create_tensor([[1.0, 2.0], [3.0, 4.0]])
diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py
index 34848af53b..d51e142da1 100644
--- a/tensorflow/python/framework/constant_op.py
+++ b/tensorflow/python/framework/constant_op.py
@@ -108,7 +108,10 @@ def convert_to_eager_tensor(value, ctx, dtype=None):
dtype, value.dtype))
return value
if dtype is not None:
- dtype = dtype.as_datatype_enum
+ try:
+ dtype = dtype.as_datatype_enum
+ except AttributeError:
+ dtype = dtypes.as_dtype(dtype).as_datatype_enum
device = ctx.device_name
handle = ctx._handle # pylint: disable=protected-access
if isinstance(value, (float,) + six.integer_types):