aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/kernel_tests/constant_op_eager_test.py31
-rw-r--r--tensorflow/python/lib/core/py_seq_tensor.cc39
2 files changed, 69 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/constant_op_eager_test.py b/tensorflow/python/kernel_tests/constant_op_eager_test.py
index 8e9d75667d..a0d5557b92 100644
--- a/tensorflow/python/kernel_tests/constant_op_eager_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_eager_test.py
@@ -32,6 +32,9 @@ from tensorflow.python.util import compat
# TODO(josh11b): add tests with lists/tuples, Shape.
+# TODO(ashankar): Collapse with tests in constant_op_test.py and use something
+# like the test_util.run_in_graph_and_eager_modes decorator to confirm
+# equivalence between graph and eager execution.
class ConstantTest(test.TestCase):
def _testCpu(self, x):
@@ -280,6 +283,34 @@ class ConstantTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, None):
constant_op.constant([[1, 2], [3], [4, 5]])
+ # TODO(ashankar): This test fails with graph construction since
+ # tensor_util.make_tensor_proto (invoked from constant_op.constant)
+ # does not handle iterables (it relies on numpy conversion).
+ # For consistency, should graph construction handle Python objects
+ # that implement the sequence protocol (but not numpy conversion),
+ # or should eager execution fail on such sequences?
+ def testCustomSequence(self):
+
+ # This is inspired by how many objects in pandas are implemented:
+ # - They implement the Python sequence protocol
+ # - But may raise a KeyError on __getitem__(self, 0)
+ # See https://github.com/tensorflow/tensorflow/issues/20347
+ class MySeq(object):
+
+ def __getitem__(self, key):
+ if key != 1 and key != 3:
+ raise KeyError(key)
+ return key
+
+ def __len__(self):
+ return 2
+
+ def __iter__(self):
+ l = list([1, 3])
+ return l.__iter__()
+
+ self.assertAllEqual([1, 3], self.evaluate(constant_op.constant(MySeq())))
+
class AsTensorTest(test.TestCase):
diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc
index 386be35ba2..3b4f12ae31 100644
--- a/tensorflow/python/lib/core/py_seq_tensor.cc
+++ b/tensorflow/python/lib/core/py_seq_tensor.cc
@@ -88,6 +88,41 @@ bool IsPyDimension(PyObject* obj) {
return ret;
}
+// Sets *elem to a NEW reference to an element in seq on success.
+// REQUIRES: PySequence_Check(seq) && PySequence_Length(seq) > 0.
+Status SampleElementFromSequence(PyObject* seq, PyObject** elem) {
+ *elem = PySequence_GetItem(seq, 0);
+ if (*elem != nullptr) return Status::OK();
+ // seq may implement the sequence protocol (i.e., implement __getitem__)
+ // but may legitimately not have a 0-th element (__getitem__(self, 0)
+ // raises a KeyError). For example:
+ // seq = pandas.Series([0, 1, 2], index=[2, 4, 6])
+ //
+ // We don't actually care for the element at key 0, any element will do
+ // for inferring the element types. All elements are expected to
+ // have the same type, and this will be validated when converting
+ // to an EagerTensor.
+ PyErr_Clear();
+ Safe_PyObjectPtr iter(PyObject_GetIter(seq));
+ if (PyErr_Occurred()) {
+ return errors::InvalidArgument("Cannot infer dtype of a ",
+ Py_TYPE(seq)->tp_name,
+ " object: ", PyExceptionFetch());
+ }
+ *elem = PyIter_Next(iter.get());
+ if (PyErr_Occurred()) {
+ return errors::InvalidArgument(
+ "Cannot infer dtype of a ", Py_TYPE(seq)->tp_name,
+ " object, as iter(<object>).next() failed: ", PyExceptionFetch());
+ }
+ if (*elem == nullptr) {
+ return errors::InvalidArgument("Cannot infer dtype of a ",
+ Py_TYPE(seq)->tp_name,
+ " object since it is an empty sequence");
+ }
+ return Status::OK();
+}
+
Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
std::vector<Safe_PyObjectPtr> refs_to_clean;
while (true) {
@@ -98,7 +133,9 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
auto length = PySequence_Length(obj);
if (length > 0) {
shape->AddDim(length);
- obj = PySequence_GetItem(obj, 0);
+ PyObject* elem = nullptr;
+ TF_RETURN_IF_ERROR(SampleElementFromSequence(obj, &elem));
+ obj = elem;
refs_to_clean.push_back(make_safe(obj));
continue;
} else if (length == 0) {