diff options
Diffstat (limited to 'tensorflow/contrib/session_bundle/gc_test.py')
-rw-r--r-- | tensorflow/contrib/session_bundle/gc_test.py | 115 |
1 files changed, 115 insertions, 0 deletions
diff --git a/tensorflow/contrib/session_bundle/gc_test.py b/tensorflow/contrib/session_bundle/gc_test.py new file mode 100644 index 0000000000..c71cbb9e5f --- /dev/null +++ b/tensorflow/contrib/session_bundle/gc_test.py @@ -0,0 +1,115 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for session_bundle.gc.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import re + + +from six.moves import xrange # pylint: disable=redefined-builtin + +import tensorflow as tf + +from tensorflow.contrib.session_bundle import gc +from tensorflow.python.framework import test_util +from tensorflow.python.platform import gfile + + +def tearDownModule(): + gfile.DeleteRecursively(tf.test.get_temp_dir()) + + +class GcTest(test_util.TensorFlowTestCase): + + def testLargestExportVersions(self): + paths = [gc.Path("/foo", 8), gc.Path("/foo", 9), gc.Path("/foo", 10)] + newest = gc.largest_export_versions(2) + n = newest(paths) + self.assertEquals(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)]) + + def testModExportVersion(self): + paths = [gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6), + gc.Path("/foo", 9)] + mod = gc.mod_export_version(2) + self.assertEquals(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 6)]) + mod = gc.mod_export_version(3) + self.assertEquals(mod(paths), [gc.Path("/foo", 6), gc.Path("/foo", 9)]) + + def testOneOfEveryNExportVersions(self): + paths = [gc.Path("/foo", 0), gc.Path("/foo", 1), gc.Path("/foo", 3), + gc.Path("/foo", 5), gc.Path("/foo", 6), gc.Path("/foo", 7), + gc.Path("/foo", 8), gc.Path("/foo", 33)] + one_of = gc.one_of_every_n_export_versions(3) + self.assertEquals(one_of(paths), + [gc.Path("/foo", 3), gc.Path("/foo", 6), + gc.Path("/foo", 8), gc.Path("/foo", 33)]) + + def testOneOfEveryNExportVersionsZero(self): + # Zero is a special case since it gets rolled into the first interval. + # Test that here. + paths = [gc.Path("/foo", 0), gc.Path("/foo", 4), gc.Path("/foo", 5)] + one_of = gc.one_of_every_n_export_versions(3) + self.assertEquals(one_of(paths), + [gc.Path("/foo", 0), gc.Path("/foo", 5)]) + + def testUnion(self): + paths = [] + for i in xrange(10): + paths.append(gc.Path("/foo", i)) + f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3)) + self.assertEquals( + f(paths), [gc.Path("/foo", 0), gc.Path("/foo", 3), + gc.Path("/foo", 6), gc.Path("/foo", 7), + gc.Path("/foo", 8), gc.Path("/foo", 9)]) + + def testNegation(self): + paths = [gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6), + gc.Path("/foo", 9)] + mod = gc.negation(gc.mod_export_version(2)) + self.assertEquals( + mod(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)]) + mod = gc.negation(gc.mod_export_version(3)) + self.assertEquals( + mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)]) + + def testPathsWithParse(self): + base_dir = os.path.join(tf.test.get_temp_dir(), "paths_parse") + self.assertFalse(gfile.Exists(base_dir)) + for p in xrange(3): + gfile.MakeDirs(os.path.join(base_dir, "%d" % p)) + # add a base_directory to ignore + gfile.MakeDirs(os.path.join(base_dir, "ignore")) + + # create a simple parser that pulls the export_version from the directory. + def parser(path): + match = re.match("^" + base_dir + "/(\\d+)$", path.path) + if not match: + return None + return path._replace(export_version=int(match.group(1))) + + self.assertEquals( + gc.get_paths(base_dir, parser=parser), + [gc.Path(os.path.join(base_dir, "0"), 0), + gc.Path(os.path.join(base_dir, "1"), 1), + gc.Path(os.path.join(base_dir, "2"), 2)]) + + +if __name__ == "__main__": + tf.test.main() |