aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-09-07 12:48:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-07 12:51:30 -0700
commitd644e729caa4071cc2571cf679acac4392117848 (patch)
tree3393aae6fd8ecb01f3e02af9ccefefba77b1430c /tensorflow/python/framework
parentca92311cbdd3cecbb41c3f0012bcab90eef0c26f (diff)
Add PyMemberDef for __dict__ on eager tensors.
This fixes dir() calls on instances of eager tensors so that it correctly accesses the __dict__ of EagerTensorType. Earlier it would fail due to an infinite "loop" in subtype_dict: https://github.com/python/cpython/blob/7e610bcdf128f61b925654e4fa80fbac83537d0e/Objects/typeobject.c#L2145 get_builtin_base_with_dict will return the same type (though I'm not sure this is reasonable behavior given its name). The __dict__ getter for the type is subtype_dict creating an infinite tail recursion. PiperOrigin-RevId: 212020695
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r--tensorflow/python/framework/test_util.py38
1 file changed, 20 insertions, 18 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 0925598e33..4bece9e25e 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -465,29 +465,31 @@ def assert_no_new_pyobjects_executing_eagerly(f):
f(self, **kwargs)
gc.collect()
previous_count = len(gc.get_objects())
- collection_sizes_before = {
- collection: len(ops.get_collection(collection))
- for collection in ops.get_default_graph().collections
- }
+ if ops.has_default_graph():
+ collection_sizes_before = {
+ collection: len(ops.get_collection(collection))
+ for collection in ops.get_default_graph().collections
+ }
for _ in range(3):
f(self, **kwargs)
# Note that gc.get_objects misses anything that isn't subject to garbage
# collection (C types). Collections are a common source of leaks, so we
# test for collection sizes explicitly.
- for collection_key in ops.get_default_graph().collections:
- collection = ops.get_collection(collection_key)
- size_before = collection_sizes_before.get(collection_key, 0)
- if len(collection) > size_before:
- raise AssertionError(
- ("Collection %s increased in size from "
- "%d to %d (current items %s).") % (collection_key, size_before,
- len(collection), collection))
- # Make sure our collection checks don't show up as leaked memory by
- # removing references to temporary variables.
- del collection
- del collection_key
- del size_before
- del collection_sizes_before
+ if ops.has_default_graph():
+ for collection_key in ops.get_default_graph().collections:
+ collection = ops.get_collection(collection_key)
+ size_before = collection_sizes_before.get(collection_key, 0)
+ if len(collection) > size_before:
+ raise AssertionError(
+ ("Collection %s increased in size from "
+ "%d to %d (current items %s).") %
+ (collection_key, size_before, len(collection), collection))
+ # Make sure our collection checks don't show up as leaked memory by
+ # removing references to temporary variables.
+ del collection
+ del collection_key
+ del size_before
+ del collection_sizes_before
gc.collect()
# There should be no new Python objects hanging around.
new_count = len(gc.get_objects())