diff options
author | 2018-07-25 11:20:00 -0700 | |
---|---|---|
committer | 2018-07-25 11:24:22 -0700 | |
commit | 972920262107f5900abd79611e9432ddc6cd810b (patch) | |
tree | ec0adf7736b9e0d1e6d702559bb902cd465946f3 | |
parent | e99262f6ec4a9ea8921d390945dbf0cd90d58d34 (diff) |
Implement __index__ in EagerTensor.
This makes range(t) for an integer, scalar Tensor work in Python 3.
PiperOrigin-RevId: 206024230
-rw-r--r-- | tensorflow/python/eager/ops_test.py | 4 | ||||
-rw-r--r-- | tensorflow/python/framework/ops.py | 5 |
2 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py index fc76ede4c5..17a090d526 100644 --- a/tensorflow/python/eager/ops_test.py +++ b/tensorflow/python/eager/ops_test.py @@ -370,6 +370,10 @@ class OpsTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError): float(x) + def testRange(self): + x = constant_op.constant(2) + self.assertEqual([0, 1], list(range(x))) + def testFormatString(self): x = constant_op.constant(3.1415) self.assertEqual('3.14', '{:.2f}'.format(x)) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 6a5c44e4d9..0fd028ebf0 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -709,7 +709,7 @@ class _EagerTensorBase(Tensor): raise ValueError("Resource handles are not convertible to numpy.") return self._cpu_nograd()._numpy() # pylint: disable=protected-access - # __int__ and __float__ may copy the tensor to CPU and + # __int__, __float__ and __index__ may copy the tensor to CPU and # only work for scalars; values are cast as per numpy. def __int__(self): return int(self.numpy()) @@ -717,6 +717,9 @@ class _EagerTensorBase(Tensor): def __float__(self): return float(self.numpy()) + def __index__(self): + return int(self.numpy()) + def __array__(self, dtype=None): return np.array(self.numpy(), dtype=dtype) |