aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/constant_op_eager_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/constant_op_eager_test.py')
-rw-r--r--tensorflow/python/kernel_tests/constant_op_eager_test.py31
1 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/constant_op_eager_test.py b/tensorflow/python/kernel_tests/constant_op_eager_test.py
index 8e9d75667d..a0d5557b92 100644
--- a/tensorflow/python/kernel_tests/constant_op_eager_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_eager_test.py
@@ -32,6 +32,9 @@ from tensorflow.python.util import compat
# TODO(josh11b): add tests with lists/tuples, Shape.
+# TODO(ashankar): Collapse with tests in constant_op_test.py and use something
+# like the test_util.run_in_graph_and_eager_modes decorator to confirm
+# equivalence between graph and eager execution.
class ConstantTest(test.TestCase):
def _testCpu(self, x):
@@ -280,6 +283,34 @@ class ConstantTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, None):
constant_op.constant([[1, 2], [3], [4, 5]])
+ # TODO(ashankar): This test fails with graph construction since
+ # tensor_util.make_tensor_proto (invoked from constant_op.constant)
+ # does not handle iterables (it relies on numpy conversion).
+ # For consistency, should graph construction handle Python objects
+ # that implement the sequence protocol (but not numpy conversion),
+ # or should eager execution fail on such sequences?
+ def testCustomSequence(self):
+
+ # This is inspired by how many objects in pandas are implemented:
+ # - They implement the Python sequence protocol
+ # - But may raise a KeyError on __getitem__(self, 0)
+ # See https://github.com/tensorflow/tensorflow/issues/20347
+ class MySeq(object):
+
+ def __getitem__(self, key):
+ if key != 1 and key != 3:
+ raise KeyError(key)
+ return key
+
+ def __len__(self):
+ return 2
+
+ def __iter__(self):
+ l = list([1, 3])
+ return l.__iter__()
+
+ self.assertAllEqual([1, 3], self.evaluate(constant_op.constant(MySeq())))
+
class AsTensorTest(test.TestCase):