aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/eager/tensor_test.py5
-rw-r--r--tensorflow/python/framework/constant_op.py5
2 files changed, 9 insertions, 1 deletions
diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py
index e31c03c08d..b52bbe44d4 100644
--- a/tensorflow/python/eager/tensor_test.py
+++ b/tensorflow/python/eager/tensor_test.py
@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
@@ -102,6 +103,10 @@ class TFETensorTest(test_util.TensorFlowTestCase):
t = _create_tensor(n)
self.assertAllEqual([[1, 2], [3, 4]], t)
+ def testConstantDtype(self):
+ self.assertEqual(constant_op.constant(1.0, dtype=np.int64).dtype,
+ dtypes.int64)
+
def testTensorAndNumpyMatrix(self):
expected = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
actual = _create_tensor([[1.0, 2.0], [3.0, 4.0]])
diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py
index 34848af53b..d51e142da1 100644
--- a/tensorflow/python/framework/constant_op.py
+++ b/tensorflow/python/framework/constant_op.py
@@ -108,7 +108,10 @@ def convert_to_eager_tensor(value, ctx, dtype=None):
dtype, value.dtype))
return value
if dtype is not None:
- dtype = dtype.as_datatype_enum
+ try:
+ dtype = dtype.as_datatype_enum
+ except AttributeError:
+ dtype = dtypes.as_dtype(dtype).as_datatype_enum
device = ctx.device_name
handle = ctx._handle # pylint: disable=protected-access
if isinstance(value, (float,) + six.integer_types):