aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2018-07-25 11:20:00 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-25 11:24:22 -0700
commit972920262107f5900abd79611e9432ddc6cd810b (patch)
treeec0adf7736b9e0d1e6d702559bb902cd465946f3
parente99262f6ec4a9ea8921d390945dbf0cd90d58d34 (diff)
Implement __index__ in EagerTensor.
This makes range(t) for an integer, scalar Tensor work in Python 3. PiperOrigin-RevId: 206024230
-rw-r--r--tensorflow/python/eager/ops_test.py4
-rw-r--r--tensorflow/python/framework/ops.py5
2 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py
index fc76ede4c5..17a090d526 100644
--- a/tensorflow/python/eager/ops_test.py
+++ b/tensorflow/python/eager/ops_test.py
@@ -370,6 +370,10 @@ class OpsTest(test_util.TensorFlowTestCase):
with self.assertRaises(TypeError):
float(x)
+ def testRange(self):
+ x = constant_op.constant(2)
+ self.assertEqual([0, 1], list(range(x)))
+
def testFormatString(self):
x = constant_op.constant(3.1415)
self.assertEqual('3.14', '{:.2f}'.format(x))
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 6a5c44e4d9..0fd028ebf0 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -709,7 +709,7 @@ class _EagerTensorBase(Tensor):
raise ValueError("Resource handles are not convertible to numpy.")
return self._cpu_nograd()._numpy() # pylint: disable=protected-access
- # __int__ and __float__ may copy the tensor to CPU and
+ # __int__, __float__ and __index__ may copy the tensor to CPU and
# only work for scalars; values are cast as per numpy.
def __int__(self):
return int(self.numpy())
@@ -717,6 +717,9 @@ class _EagerTensorBase(Tensor):
def __float__(self):
return float(self.numpy())
+ def __index__(self):
+ return int(self.numpy())
+
def __array__(self, dtype=None):
return np.array(self.numpy(), dtype=dtype)