diff options
author | Igor Ganichev <iga@google.com> | 2018-09-12 13:32:04 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-12 13:37:15 -0700 |
commit | 52d9dbfa8ed7bc8b91f1a1be706cf77314b1c687 (patch) | |
tree | 54c82f0460f003f6f5ba75a9ab081017909cf1da /tensorflow/python/eager | |
parent | 5d1de24583aabeb2cb883ab197ae2b8d5446c565 (diff) |
Use WeakKeyDictionaries for global Keras {graph->...} maps
These globals were holding onto graphs including FuncGraphs, which
held onto captured tensors leaving garbage around.
This change also adds a test to catch garbage like this in the future.
To make the test work, I needed to manually breakup some reference
cycles caused by OrderedDicts. We should probably have a custom impl
of OrderedDict similar to the one in Python3 and avoid these issues.
PiperOrigin-RevId: 212694290
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r-- | tensorflow/python/eager/function_test.py | 47 |
1 files changed, 46 insertions, 1 deletions
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index e6a49b66cf..d2b1d9c8a7 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -25,6 +25,7 @@ import sys import numpy from tensorflow.core.protobuf import config_pb2 +from tensorflow.python import keras from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import backprop from tensorflow.python.eager import context @@ -38,6 +39,7 @@ from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util +from tensorflow.python.keras.engine import training as keras_training from tensorflow.python.layers import convolutional from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops @@ -57,6 +59,21 @@ from tensorflow.python.util import compat from tensorflow.python.util import nest +class MiniModel(keras_training.Model): + """Minimal model for mnist. + + Useful for testing and debugging on slow TPU simulators. + """ + + def __init__(self): + super(MiniModel, self).__init__(name='') + self.fc = keras.layers.Dense(1, name='fc', kernel_initializer='ones', + bias_initializer='ones') + + def call(self, inputs, training=True): + return self.fc(inputs) + + @test_util.with_c_shapes class FunctionTest(test.TestCase): @@ -1005,6 +1022,7 @@ class FunctionTest(test.TestCase): with ops.get_default_graph().as_default(): create_variable() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testLayerInDefun(self): conv = convolutional.Conv2D( filters=1, @@ -1018,7 +1036,34 @@ class FunctionTest(test.TestCase): x = array_ops.ones([1, 2, 2, 1]) y = model(x) - self.assertAllEqual([[[[4.0]]]], y.numpy()) + + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + + self.assertAllEqual([[[[4.0]]]], self.evaluate(y)) + + # Remove reference cycles in model + test_util.dismantle_polymorphic_function(model) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testDefunKerasModelCall(self): + model = MiniModel() + model.call = function.defun(model.call) + + x = array_ops.ones([1, 2]) + y = model(x) + + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + + self.assertAllEqual([[3.0]], self.evaluate(y)) + + # Remove reference cycles in defun. + test_util.dismantle_polymorphic_function(model.call) + # Break the reference cycle between the MiniModel and the defun: + # MiniModel --(through its `call` method)--> PolymorphicFunction + # PolymorphicFunction --(instancemethod on MiniModel)--> MiniModel + del model.call # Note: The ConfigProto below unfortunately only configures graph # construction. Eager's configuration is controlled in `__main__`. |