about summary refs log tree commit diff homepage
path: root/tensorflow/python/eager
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2018-09-12 13:32:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-12 13:37:15 -0700
commit52d9dbfa8ed7bc8b91f1a1be706cf77314b1c687 (patch)
tree54c82f0460f003f6f5ba75a9ab081017909cf1da /tensorflow/python/eager
parent5d1de24583aabeb2cb883ab197ae2b8d5446c565 (diff)
Use WeakKeyDictionaries for global Keras {graph->...} maps
These globals were holding onto graphs including FuncGraphs, which held onto captured tensors leaving garbage around. This change also adds a test to catch garbage like this in the future. To make the test work, I needed to manually breakup some reference cycles caused by OrderedDicts. We should probably have a custom impl of OrderedDict similar to the one in Python3 and avoid these issues. PiperOrigin-RevId: 212694290
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--tensorflow/python/eager/function_test.py47
1 file changed, 46 insertions, 1 deletion
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index e6a49b66cf..d2b1d9c8a7 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -25,6 +25,7 @@ import sys
import numpy
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python import keras
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
@@ -38,6 +39,7 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.layers import convolutional
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
@@ -57,6 +59,21 @@ from tensorflow.python.util import compat
from tensorflow.python.util import nest
+class MiniModel(keras_training.Model):
+ """Minimal model for mnist.
+
+ Useful for testing and debugging on slow TPU simulators.
+ """
+
+ def __init__(self):
+ super(MiniModel, self).__init__(name='')
+ self.fc = keras.layers.Dense(1, name='fc', kernel_initializer='ones',
+ bias_initializer='ones')
+
+ def call(self, inputs, training=True):
+ return self.fc(inputs)
+
+
@test_util.with_c_shapes
class FunctionTest(test.TestCase):
@@ -1005,6 +1022,7 @@ class FunctionTest(test.TestCase):
with ops.get_default_graph().as_default():
create_variable()
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testLayerInDefun(self):
conv = convolutional.Conv2D(
filters=1,
@@ -1018,7 +1036,34 @@ class FunctionTest(test.TestCase):
x = array_ops.ones([1, 2, 2, 1])
y = model(x)
- self.assertAllEqual([[[[4.0]]]], y.numpy())
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+
+ self.assertAllEqual([[[[4.0]]]], self.evaluate(y))
+
+ # Remove reference cycles in model
+ test_util.dismantle_polymorphic_function(model)
+
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+ def testDefunKerasModelCall(self):
+ model = MiniModel()
+ model.call = function.defun(model.call)
+
+ x = array_ops.ones([1, 2])
+ y = model(x)
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+
+ self.assertAllEqual([[3.0]], self.evaluate(y))
+
+ # Remove reference cycles in defun.
+ test_util.dismantle_polymorphic_function(model.call)
+ # Break the reference cycle between the MiniModel and the defun:
+ # MiniModel --(through its `call` method)--> PolymorphicFunction
+ # PolymorphicFunction --(instancemethod on MiniModel)--> MiniModel
+ del model.call
# Note: The ConfigProto below unfortunately only configures graph
# construction. Eager's configuration is controlled in `__main__`.