diff options
Diffstat (limited to 'tensorflow/contrib/session_bundle/gc.py')
-rw-r--r-- | tensorflow/contrib/session_bundle/gc.py | 204 |
1 files changed, 204 insertions, 0 deletions
diff --git a/tensorflow/contrib/session_bundle/gc.py b/tensorflow/contrib/session_bundle/gc.py new file mode 100644 index 0000000000..ad7389d96f --- /dev/null +++ b/tensorflow/contrib/session_bundle/gc.py @@ -0,0 +1,204 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""System for specifying garbage collection (GC) of path based data. + +This framework allows for GC of data specified by path names, for example files +on disk. gc.Path objects each represent a single item stored at a path and may +be a base directory, + /tmp/exports/0/... + /tmp/exports/1/... + ... +or a fully qualified file, + /tmp/train-1.ckpt + /tmp/train-2.ckpt + ... + +A gc filter function takes and returns a list of gc.Path items. Filter +functions are responsible for selecting Path items for preservation or deletion. +Note that functions should always return a sorted list. + +For example, + base_dir = "/tmp" + # create the directories + for e in xrange(10): + os.mkdir("%s/%d" % (base_dir, e), 0o755) + + # create a simple parser that pulls the export_version from the directory + def parser(path): + match = re.match("^" + base_dir + "/(\\d+)$", path.path) + if not match: + return None + return path._replace(export_version=int(match.group(1))) + + path_list = gc.get_paths("/tmp", parser) # contains all ten Paths + + every_fifth = gc.mod_export_version(5) + print every_fifth(path_list) # shows ["/tmp/0", "/tmp/5"] + + largest_three = gc.largest_export_versions(3) + print largest_three(all_paths) # shows ["/tmp/7", "/tmp/8", "/tmp/9"] + + both = gc.union(every_fifth, largest_three) + print both(all_paths) # shows ["/tmp/0", "/tmp/5", + # "/tmp/7", "/tmp/8", "/tmp/9"] + # delete everything not in 'both' + to_delete = gc.negation(both) + for p in to_delete(all_paths): + gfile.DeleteRecursively(p.path) # deletes: "/tmp/1", "/tmp/2", + # "/tmp/3", "/tmp/4", "/tmp/6", +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import heapq +import math +import os + +from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.platform import gfile + +Path = collections.namedtuple('Path', 'path export_version') + + +def largest_export_versions(n): + """Creates a filter that keeps the largest n export versions. + + Args: + n: number of versions to keep. + + Returns: + A filter function that keeps the n largest paths. + """ + def keep(paths): + heap = [] + for idx, path in enumerate(paths): + if path.export_version: + heapq.heappush(heap, (path.export_version, idx)) + keepers = [paths[i] for _, i in heapq.nlargest(n, heap)] + return sorted(keepers) + + return keep + + +def one_of_every_n_export_versions(n): + """Creates a filter that keeps one of every n export versions. + + Args: + n: interval size. + + Returns: + A filter function that keeps exactly one path from each interval + [0, n], (n, 2n], (2n, 3n], etc... If more than one path exists in an + interval the largest is kept. + """ + def keep(paths): + keeper_map = {} # map from interval to largest path seen in that interval + for p in paths: + if p.export_version is None: + # Skip missing export_versions. + continue + # Find the interval (with a special case to map export_version = 0 to + # interval 0. + interval = math.floor( + (p.export_version - 1) / n) if p.export_version else 0 + existing = keeper_map.get(interval, None) + if (not existing) or (existing.export_version < p.export_version): + keeper_map[interval] = p + return sorted(keeper_map.values()) + + return keep + + +def mod_export_version(n): + """Creates a filter that keeps every export that is a multiple of n. + + Args: + n: step size. + + Returns: + A filter function that keeps paths where export_version % n == 0. + """ + def keep(paths): + keepers = [] + for p in paths: + if p.export_version % n == 0: + keepers.append(p) + return sorted(keepers) + return keep + + +def union(lf, rf): + """Creates a filter that keeps the union of two filters. + + Args: + lf: first filter + rf: second filter + + Returns: + A filter function that keeps the n largest paths. + """ + def keep(paths): + l = set(lf(paths)) + r = set(rf(paths)) + return sorted(list(l|r)) + return keep + + +def negation(f): + """Negate a filter. + + Args: + f: filter function to invert + + Returns: + A filter function that returns the negation of f. + """ + def keep(paths): + l = set(paths) + r = set(f(paths)) + return sorted(list(l-r)) + return keep + + +def get_paths(base_dir, parser): + """Gets a list of Paths in a given directory. + + Args: + base_dir: directory. + parser: a function which gets the raw Path and can augment it with + information such as the export_version, or ignore the path by returning + None. An example parser may extract the export version from a path + such as "/tmp/exports/100" an another may extract from a full file + name such as "/tmp/checkpoint-99.out". + + Returns: + A list of Paths contained in the base directory with the parsing function + applied. + By default the following fields are populated, + - Path.path + The parsing function is responsible for populating, + - Path.export_version + """ + raw_paths = gfile.ListDirectory(base_dir) + paths = [] + for r in raw_paths: + p = parser(Path(os.path.join(base_dir, r), None)) + if p: + paths.append(p) + return sorted(paths) |