aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/partitioned_variables.py
blob: 174cabdf8027e75c780441d06a98a24c19be0cfc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Helper functions for creating partitioned variables.

This is a convenient abstraction to partition a large variable across
multiple smaller variables that can be assigned to different devices.

The full variable can be reconstructed by concatenating the smaller variables.
Using partitioned variables instead of a single variable is mostly a
performance choice.  It however also has an impact on:

1. Random initialization, as the random number generator is called once per
   slice
2. Updates, as they happen in parallel across slices

A key design goal is to allow a different graph to repartition a variable
with the same name but different slicings, including possibly no partitions.

TODO(touts): If an initializer provides a seed, the seed must be changed
deterministically for each slice, maybe by adding one to it, otherwise each
slice will use the same values.  Maybe this can be done by passing the
slice offsets to the initializer functions.

Typical usage:

```python
# Create a list of partitioned variables with:
vs = create_partitioned_variables(
    <shape>, <slicing>, <initializer>, name=<optional-name>)

# Pass the list as inputs to embedding_lookup for sharded, parallel lookup:
y = embedding_lookup(vs, ids, partition_strategy="div")

# Or fetch the variables in parallel to speed up large matmuls:
z = matmul(x, concat(vs, slice_dim))
```
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export

# Names exported by `from ... import *`: the three partitioner factories and
# the legacy `create_partitioned_variables` entry point defined in this module.
__all__ = [
    "create_partitioned_variables",
    "variable_axis_size_partitioner",
    "min_max_variable_partitioner",
    "fixed_size_partitioner",
]


@tf_export("variable_axis_size_partitioner")
def variable_axis_size_partitioner(
    max_shard_bytes, axis=0, bytes_per_string_element=16, max_shards=None):
  """Get a partitioner for VariableScope to keep shards below `max_shard_bytes`.

  This partitioner will shard a Variable along one axis, attempting to keep
  the maximum shard size below `max_shard_bytes`.  In practice, this is not
  always possible when sharding along only one axis.  When this happens,
  this axis is sharded as much as possible (i.e., every index along the axis
  becomes a separate shard).

  If the partitioner hits the `max_shards` limit, then each shard may end up
  larger than `max_shard_bytes`. By default `max_shards` equals `None` and no
  limit on the number of shards is enforced.

  One reasonable value for `max_shard_bytes` is `(64 << 20) - 1`, or almost
  `64MB`, to keep below the protobuf byte limit.

  Args:
    max_shard_bytes: The maximum size any given shard is allowed to be.
    axis: The axis to partition along.  Default: outermost axis.
    bytes_per_string_element: If the `Variable` is of type string, this provides
      an estimate of how large each scalar in the `Variable` is.
    max_shards: Optional `int`, the maximum number of shards to create.  It
      takes precedence over `max_shard_bytes`.  A falsy value (`None` or `0`)
      disables the limit.

  Returns:
    A partition function usable as the `partitioner` argument to
    `variable_scope`, `get_variable`, and `get_partitioned_variable_list`.

  Raises:
    ValueError: If `max_shard_bytes` or `bytes_per_string_element` is
      non-positive, or if `max_shards` is negative.
  """
  if max_shard_bytes < 1 or bytes_per_string_element < 1:
    raise ValueError(
        "Both max_shard_bytes and bytes_per_string_element must be positive.")
  if max_shards and max_shards < 1:
    raise ValueError(
        "max_shards must be positive.")

  def _partitioner(shape, dtype):
    """Partitioner that partitions shards to have max_shard_bytes total size.

    Args:
      shape: A `TensorShape`.
      dtype: A `DType`.

    Returns:
      A list with one entry per axis of `shape` giving the number of slices
      for that axis.  Every entry is 1 except possibly the one for `axis`.

    Raises:
      ValueError: If shape is not a fully defined `TensorShape` or dtype is not
        a `DType`.
    """
    # The offending value is wrapped in a 1-tuple before formatting: a bare
    # tuple operand (e.g. a shape passed as `(10, 20)`) would be unpacked by
    # `%` and raise TypeError instead of the intended ValueError.
    if not isinstance(shape, tensor_shape.TensorShape):
      raise ValueError("shape is not a TensorShape: %s" % (shape,))
    if not shape.is_fully_defined():
      raise ValueError("shape is not fully defined: %s" % (shape,))
    if not isinstance(dtype, dtypes.DType):
      raise ValueError("dtype is not a DType: %s" % (dtype,))

    if dtype.base_dtype == dtypes.string:
      element_size = bytes_per_string_element
    else:
      element_size = dtype.size

    partitions = [1] * shape.ndims
    axis_size = shape[axis].value
    num_elements = shape.num_elements()
    if num_elements == 0:
      # An empty variable has nothing to split, and the size arithmetic below
      # would divide by zero; keep a single shard on every axis.
      return partitions

    bytes_per_slice = 1.0 * (num_elements / axis_size) * element_size
    # How many slices can we fit on one shard of size at most max_shard_bytes?
    # At least one slice is required.
    slices_per_shard = max(1, math.floor(max_shard_bytes / bytes_per_slice))
    # How many shards do we need for axis given that each shard fits
    # slices_per_shard slices from a total of axis_size slices?
    axis_shards = int(math.ceil(1.0 * axis_size / slices_per_shard))
    if max_shards:
      axis_shards = min(max_shards, axis_shards)

    partitions[axis] = axis_shards

    return partitions

  return _partitioner


@tf_export("min_max_variable_partitioner")
def min_max_variable_partitioner(max_partitions=1, axis=0,
                                 min_slice_size=256 << 10,
                                 bytes_per_string_element=16):
  """Build a partitioner that guarantees a minimum size for every slice.

  The returned partitioner splits a variable along `axis` into as many pieces
  as possible while keeping each piece at least `min_slice_size` bytes, and
  never produces more than `max_partitions` pieces.

  Args:
    max_partitions: Upper bound on the number of partitions. Defaults to 1.
    axis: Axis along which to partition the variable. Defaults to 0.
    min_slice_size: Minimum size of the variable slice per partition. Defaults
      to 256K.
    bytes_per_string_element: If the `Variable` is of type string, this provides
      an estimate of how large each scalar in the `Variable` is.

  Returns:
    A partition function usable as the `partitioner` argument to
    `variable_scope`, `get_variable`, and `get_partitioned_variable_list`.

  """
  def _partitioner(shape, dtype):
    """Compute the per-axis partition counts for one variable.

    Ex: for a float32 variable with shape=[1024, 1024] the total size is
      1024 * 1024 * 4 bytes, i.e. 16 slices of the default 256K:
      If `max_partitions` >= 16, the result is [16, 1].
      If `max_partitions` < 16, the result is [`max_partitions`, 1].

    Args:
      shape: Shape of the variable.
      dtype: Type of the variable.

    Returns:
      List with one partition count per axis; only the entry for `axis` can
      differ from 1.

    Raises:
      ValueError: If axis to partition along does not exist for the variable.
    """
    rank = len(shape)
    if rank <= axis:
      raise ValueError("Can not partition variable along axis %d when shape is "
                       "only %s" % (axis, shape))

    # Strings have no fixed width, so the caller-supplied estimate is used.
    is_string = dtype.base_dtype == dtypes.string
    element_bytes = bytes_per_string_element if is_string else dtype.size

    variable_bytes = shape.num_elements() * element_bytes
    # Possibly fractional number of minimum-size slices in the variable.
    slices_wanted = variable_bytes / min_slice_size

    result = [1] * rank
    # The count is capped by the axis length and by `max_partitions`, and is
    # never allowed to drop below one.
    axis_size = shape[axis].value
    capped = min(axis_size, max_partitions, int(math.ceil(slices_wanted)))
    result[axis] = max(1, capped)
    return result

  return _partitioner


@tf_export("fixed_size_partitioner")
def fixed_size_partitioner(num_shards, axis=0):
  """Build a partitioner that always requests `num_shards` along one axis.

  Args:
    num_shards: `int`, number of shards to partition variable.
    axis: `int`, axis to partition on.

  Returns:
    A partition function usable as the `partitioner` argument to
    `variable_scope`, `get_variable`, and `get_partitioned_variable_list`.
  """
  def _partitioner(shape, **unused_args):
    """Return per-axis shard counts: 1 everywhere except along `axis`."""
    shards_per_axis = [1 for _ in range(len(shape))]
    # Never ask for more shards than there are entries along the axis.
    axis_size = shape[axis].value
    shards_per_axis[axis] = min(num_shards, axis_size)
    return shards_per_axis

  return _partitioner


@tf_export("create_partitioned_variables")
def create_partitioned_variables(
    shape, slicing, initializer, dtype=dtypes.float32,
    trainable=True, collections=None, name=None, reuse=None):
  """Create a list of partitioned variables according to the given `slicing`.

  Currently only one dimension of the full variable can be sliced, and the
  full variable can be reconstructed by the concatenation of the returned
  list along that dimension.

  This function logs a deprecation warning on every call; new code should use
  `tf.get_variable` with a partitioner, or `tf.get_partitioned_variable_list`.

  Args:
    shape: List of integers.  The shape of the full variable.
    slicing: List of integers.  How to partition the variable.
      Must be of the same length as `shape`.  Each value
      indicates how many slices to create in the corresponding
      dimension.  Presently only one of the values can be more than 1;
      that is, the variable can only be sliced along one dimension.

      For convenience, the requested number of partitions does not have to
      divide the corresponding dimension evenly.  If it does not, the
      shapes of the partitions are incremented by 1 starting from partition
      0 until all slack is absorbed.  The adjustment rules may change in the
      future, but as you can save/restore these variables with different
      slicing specifications this should not be a problem.
    initializer: A `Tensor` of shape `shape` or a variable initializer
      function.  If a function, it will be called once for each slice,
      passing the shape and data type of the slice as parameters.  The
      function must return a tensor with the same shape as the slice.
    dtype: Type of the variables. Ignored if `initializer` is a `Tensor`.
    trainable: If True also add all the variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES`.
    collections: List of graph collections keys to add the variables to.
      Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    name: Optional name for the full variable.  Defaults to
      `"PartitionedVariable"` and gets uniquified automatically.
    reuse: Boolean or `None`; if `True` and name is set, it would reuse
      previously created variables. If `False` it will create new variables.
      If `None`, it would inherit the parent scope reuse.

  Returns:
    A list of Variables corresponding to the slicing.

  Raises:
    ValueError: If `shape` and `slicing` differ in length, or if `shape` has
      rank 0.
  """
  logging.warn(
      "create_partitioned_variables is deprecated.  Use "
      "tf.get_variable with a partitioner set, or "
      "tf.get_partitioned_variable_list, instead.")

  if len(shape) != len(slicing):
    raise ValueError("The 'shape' and 'slicing' of a partitioned Variable "
                     "must have the same length: shape: %s, slicing: %s" %
                     (shape, slicing))
  if len(shape) < 1:
    # `shape` is wrapped in a 1-tuple: this branch runs exactly when `shape`
    # is empty, and formatting with a bare `()` would raise TypeError
    # ("not enough arguments for format string") instead of ValueError.
    raise ValueError("A partitioned Variable must have rank at least 1: "
                     "shape: %s" % (shape,))

  # Legacy: we are provided the slicing directly, so just pass it to
  # the partitioner.
  partitioner = lambda **unused_kwargs: slicing

  with variable_scope.variable_scope(
      name, "PartitionedVariable", reuse=reuse):
    # pylint: disable=protected-access
    partitioned_var = variable_scope._get_partitioned_variable(
        name=None,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        trainable=trainable,
        partitioner=partitioner,
        collections=collections)
    # pylint: enable=protected-access
    # Callers of this legacy API expect a plain list of the per-slice
    # variables rather than the partitioned-variable wrapper object.
    return list(partitioned_var)