diff options
-rw-r--r-- | tensorflow/python/eager/tensor_test.py | 5 | ||||
-rw-r--r-- | tensorflow/python/framework/constant_op.py | 5 |
2 files changed, 9 insertions, 1 deletions
diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index e31c03c08d..b52bbe44d4 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.python.eager import context from tensorflow.python.eager import core from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util @@ -102,6 +103,10 @@ class TFETensorTest(test_util.TensorFlowTestCase): t = _create_tensor(n) self.assertAllEqual([[1, 2], [3, 4]], t) + def testConstantDtype(self): + self.assertEqual(constant_op.constant(1.0, dtype=np.int64).dtype, + dtypes.int64) + def testTensorAndNumpyMatrix(self): expected = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32) actual = _create_tensor([[1.0, 2.0], [3.0, 4.0]]) diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py index 34848af53b..d51e142da1 100644 --- a/tensorflow/python/framework/constant_op.py +++ b/tensorflow/python/framework/constant_op.py @@ -108,7 +108,10 @@ def convert_to_eager_tensor(value, ctx, dtype=None): dtype, value.dtype)) return value if dtype is not None: - dtype = dtype.as_datatype_enum + try: + dtype = dtype.as_datatype_enum + except AttributeError: + dtype = dtypes.as_dtype(dtype).as_datatype_enum device = ctx.device_name handle = ctx._handle # pylint: disable=protected-access if isinstance(value, (float,) + six.integer_types): |