aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn/python/learn/utils/gc.py
blob: dd4376f0512f8759e05fae49ae9f70f097a969a4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

r"""System for specifying garbage collection (GC) of path based data.

This framework allows for GC of data specified by path names, for example files
on disk.  gc.Path objects each represent a single item stored at a path and may
be a base directory,
  /tmp/exports/0/...
  /tmp/exports/1/...
  ...
or a fully qualified file,
  /tmp/train-1.ckpt
  /tmp/train-2.ckpt
  ...

A gc filter function takes and returns a list of gc.Path items.  Filter
functions are responsible for selecting Path items for preservation or deletion.
Note that functions should always return a sorted list.

For example,
  base_dir = "/tmp"
  # create the directories
  for e in xrange(10):
    os.mkdir("%s/%d" % (base_dir, e), 0o755)

  # create a simple parser that pulls the export_version from the directory
  def parser(path):
    match = re.match("^" + base_dir + "/(\\d+)$", path.path)
    if not match:
      return None
    return path._replace(export_version=int(match.group(1)))

  path_list = gc.get_paths("/tmp", parser)  # contains all ten Paths

  every_fifth = gc.mod_export_version(5)
  print every_fifth(path_list) # shows ["/tmp/0", "/tmp/5"]

  largest_three = gc.largest_export_versions(3)
  print largest_three(all_paths)  # shows ["/tmp/7", "/tmp/8", "/tmp/9"]

  both = gc.union(every_fifth, largest_three)
  print both(all_paths)  # shows ["/tmp/0", "/tmp/5",
                         #        "/tmp/7", "/tmp/8", "/tmp/9"]
  # delete everything not in 'both'
  to_delete = gc.negation(both)
  for p in to_delete(all_paths):
    gfile.DeleteRecursively(p.path)  # deletes:  "/tmp/1", "/tmp/2",
                                     # "/tmp/3", "/tmp/4", "/tmp/6",
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import heapq
import math
import os

from tensorflow.python.platform import gfile

Path = collections.namedtuple('Path', 'path export_version')


def largest_export_versions(n):
  """Creates a filter that keeps the largest n export versions.

  Args:
    n: number of versions to keep.

  Returns:
    A filter function that keeps the n largest paths.
  """
  def keep(paths):
    heap = []
    for idx, path in enumerate(paths):
      if path.export_version is not None:
        heapq.heappush(heap, (path.export_version, idx))
    keepers = [paths[i] for _, i in heapq.nlargest(n, heap)]
    return sorted(keepers)

  return keep


def one_of_every_n_export_versions(n):
  """Creates a filter that keeps one of every n export versions.

  Args:
    n: interval size.

  Returns:
    A filter function that keeps exactly one path from each interval
    [0, n], (n, 2n], (2n, 3n], etc...  If more than one path exists in an
    interval the largest is kept.
  """
  def keep(paths):
    """A filter function that keeps exactly one out of every n paths."""

    keeper_map = {}  # map from interval to largest path seen in that interval
    for p in paths:
      if p.export_version is None:
        # Skip missing export_versions.
        continue
      # Find the interval (with a special case to map export_version = 0 to
      # interval 0.
      interval = math.floor(
          (p.export_version - 1) / n) if p.export_version else 0
      existing = keeper_map.get(interval, None)
      if (not existing) or (existing.export_version < p.export_version):
        keeper_map[interval] = p
    return sorted(keeper_map.values())

  return keep


def mod_export_version(n):
  """Creates a filter that keeps every export that is a multiple of n.

  Args:
    n: step size.

  Returns:
    A filter function that keeps paths where export_version % n == 0.
  """
  def keep(paths):
    keepers = []
    for p in paths:
      if p.export_version % n == 0:
        keepers.append(p)
    return sorted(keepers)
  return keep


def union(lf, rf):
  """Creates a filter that keeps the union of two filters.

  Args:
    lf: first filter
    rf: second filter

  Returns:
    A filter function that keeps the n largest paths.
  """
  def keep(paths):
    l = set(lf(paths))
    r = set(rf(paths))
    return sorted(list(l|r))
  return keep


def negation(f):
  """Negate a filter.

  Args:
    f: filter function to invert

  Returns:
    A filter function that returns the negation of f.
  """
  def keep(paths):
    l = set(paths)
    r = set(f(paths))
    return sorted(list(l-r))
  return keep


def get_paths(base_dir, parser):
  """Gets a list of Paths in a given directory.

  Args:
    base_dir: directory.
    parser: a function which gets the raw Path and can augment it with
      information such as the export_version, or ignore the path by returning
      None.  An example parser may extract the export version from a path
      such as "/tmp/exports/100" an another may extract from a full file
      name such as "/tmp/checkpoint-99.out".

  Returns:
    A list of Paths contained in the base directory with the parsing function
    applied.
    By default the following fields are populated,
      - Path.path
    The parsing function is responsible for populating,
      - Path.export_version
  """
  raw_paths = gfile.ListDirectory(base_dir)
  paths = []
  for r in raw_paths:
    p = parser(Path(os.path.join(base_dir, r), None))
    if p:
      paths.append(p)
  return sorted(paths)