# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================

"""TpuEstimator class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import threading
from six.moves import queue as Queue  # pylint: disable=redefined-builtin

from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_feed
from tensorflow.contrib.tpu.python.tpu import training_loop

from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training


_BATCH_SIZE_KEY = 'batch_size'


def _tpu_job(run_config):
  # The tpu job is determined by the run_config. Right now, this method is
  # required as tpu_config is not part of the RunConfig.
  return None if run_config.master in ['', 'local'] else 'tpu_worker'


class _SIGNAL(object):
  """Signal used to control the input thread of infeed."""
  NEXT_BATCH = 1
  STOP = 2


class InfeedThreadController(object):
  """This wraps the infeed thread and stops when Estimator train finishes.

  For model_fn wrapper, it is not possible to know when the `train` API will
  stop. It could be the cases that the `max_steps` is reached or some hook
  requests the stop in the monitored_session.

  This controller (with coordination with `TpuInfeedSessionHook`) does the
  following:

  1) It pre-infeeds one `batch` data for current TPU iterations.

  2) When `before_run` of `TpuInfeedSessionHook` is called, one more `batch`
  data will be infed.

  3) When `end` of `TpuInfeedSessionHook` is called, the thread will end
  gracefully.

  So, we might need to adjust the algorithrm here if the IO is slower than the
  computation.
  """

  def __init__(self, session, enqueue_ops, iterations):
    self._signal_queue = Queue.Queue()
    self._input_thd = threading.Thread(target=self._input_thread_fn_for_loading,
                                       args=(session, enqueue_ops, iterations))
    self._input_thd.daemon = True
    self._input_thd.start()

  def _input_thread_fn_for_loading(self, session, enqueue_ops, iterations):
    count = 0
    while True:
      signal = self._signal_queue.get()
      if signal == _SIGNAL.STOP:
        logging.info('Stop Infeed input thread.')
        return

      for i in range(iterations):
        logging.debug('InfeedEnqueue data for iteration (%d, %d)', count, i)
        session.run(enqueue_ops)
      count += 1

  def load_next_batch(self):
    self._signal_queue.put(_SIGNAL.NEXT_BATCH)

  def join(self):
    logging.info('Waiting for InputThread to exit.')
    self._signal_queue.put(_SIGNAL.STOP)
    self._input_thd.join()


class TpuInfeedSessionHook(session_run_hook.SessionRunHook):
  """A Session hook setting up the TPU initialization and infeed.

  This hook does two major things:
  1. initialize and shutdown TPU system (maybe a separated hook)
  2. launch and join the input thread for infeed.
  """

  def __init__(self, run_config, enqueue_fn):
    self._iterations = run_config.tpu_config.iterations_per_loop
    self._enqueue_fn = enqueue_fn
    self._tpu_job = _tpu_job(run_config)

  def begin(self):
    self._enqueue_ops = self._enqueue_fn()
    logging.info('TPU job name %s', self._tpu_job)
    self._init_op = [tpu.initialize_system(job=self._tpu_job)]
    self._finalize_op = [tpu.shutdown_system(job=self._tpu_job)]

  def after_create_session(self, session, coord):
    logging.info('Init TPU system')
    session.run(self._init_op)

    logging.info('Start infeed input thread controller')
    self._infeed_thd_controller = InfeedThreadController(
        session, self._enqueue_ops, self._iterations)

  def before_run(self, run_context):
    logging.info('Load next batch of data to infeed.')
    self._infeed_thd_controller.load_next_batch()

  def end(self, session):
    logging.info('Stop infeed input thread controller')
    self._infeed_thd_controller.join()

    logging.info('Shutdown TPU system.')
    session.run(self._finalize_op)


class _PerShardOutput(object):
  """Wraps input_fn's outputs into per-shard outputs.

  Used so that the wrapped model_fn can distinguish between sharded and
  unsharded inputs (e.g., for export_savedmodel()).
  """

  def __init__(self, output):
    self.output = output

  def as_list(self):
    return self.output


class TpuEstimator(estimator_lib.Estimator):
  """Estimator with TPU support.

  TpuEstimator handles many of the details of running on TPU devices, such as
  replicating inputs and models for each core, and returning to host
  periodically to run hooks.

  Note: TpuEstimator transforms a global batch size in params to a per-shard
        batch size when calling the input_fn.
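
  Example (an illustrative sketch, not an exact recipe; `my_model_fn`,
  `my_input_fn`, and `build_input_pipeline` are hypothetical user-defined
  functions, and the exact `tpu_config.RunConfig` constructor arguments may
  differ):

    def my_input_fn(params):
      # `params['batch_size']` is the per-shard batch size, derived by
      # TpuEstimator from the global batch size passed below.
      return build_input_pipeline(batch_size=params['batch_size'])

    run_config = tpu_config.RunConfig(
        master=master_address,
        tpu_config=tpu_config.TPUConfig(iterations_per_loop=100,
                                        num_shards=8))
    estimator = TpuEstimator(model_fn=my_model_fn,
                             config=run_config,
                             params={'batch_size': 1024})
    estimator.train(input_fn=my_input_fn, max_steps=10000)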
  """

  def __init__(self,
               model_fn=None,
               model_dir=None,
               config=None,
               params=None,
               use_tpu=True):
    if config is None or not isinstance(config, tpu_config.RunConfig):
      raise ValueError(
          '`config` must be provided with type `tpu_config.RunConfig`')

    if use_tpu and params is not None and _BATCH_SIZE_KEY in params:
      if not isinstance(params[_BATCH_SIZE_KEY], int):
        raise ValueError(
            '`{}` in params must be an int'.format(_BATCH_SIZE_KEY))
      params = copy.deepcopy(params)
      # The specified batch size is the batch size for the entire computation.
      # The input_fn is called per-shard, so we want to calculate the per-shard
      # batch size and pass that.
      if params[_BATCH_SIZE_KEY] % config.tpu_config.num_shards != 0:
        raise ValueError(
            'batch size {} must be divisible by number of shards {}'
            .format(params[_BATCH_SIZE_KEY], config.tpu_config.num_shards))

    if use_tpu:
      # Verifies the model_fn signature according to Estimator framework.
      estimator_lib._verify_model_fn_args(model_fn, params)  # pylint: disable=protected-access
      # We cannot store config and params in this constructor because the
      # parent constructor might change them, such as assigning a temp dir
      # for config.model_dir.
      model_function = wrapped_model_fn(model_fn)
    else:
      model_function = model_fn

    super(TpuEstimator, self).__init__(
        model_fn=model_function,
        model_dir=model_dir,
        config=config,
        params=params)
    self.use_tpu = use_tpu

  def _create_global_step(self, graph):
    """Creates a global step suitable for TPUs.

    Args:
      graph: The graph in which to create the global step.

    Returns:
      A global step `Tensor`.

    Raises:
      ValueError: if the global step tensor is already defined.
    """
    graph = graph or ops.get_default_graph()
    if training.get_global_step(graph) is not None:
      raise ValueError('"global_step" already exists.')
    # Create in proper graph and base name_scope.
    with graph.as_default() as g, g.name_scope(None):
      return variable_scope.get_variable(
          ops.GraphKeys.GLOBAL_STEP,
          shape=[],
          dtype=dtypes.int32,
          initializer=init_ops.zeros_initializer(),
          trainable=False,
          use_resource=True,
          collections=[ops.GraphKeys.GLOBAL_VARIABLES,
                       ops.GraphKeys.GLOBAL_STEP])

  def _call_input_fn(self, input_fn, mode):
    """Calls the input function.

    Args:
      input_fn: The input function.
      mode: ModeKeys

    Returns:
      Either features or (features, labels) where features and labels are:
        features - `Tensor` or dictionary of string feature name to `Tensor`.
        labels - `Tensor` or dictionary of `Tensor` with labels.

    Raises:
      ValueError: if input_fn takes invalid arguments.
    """
    if not self.use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
      return super(TpuEstimator, self)._call_input_fn(input_fn, mode)

    input_fn_args = estimator_lib._fn_args(input_fn)  # pylint: disable=protected-access
    config = self.config  # a deep copy.
    kwargs = {}
    if 'params' in input_fn_args:
      kwargs['params'] = self.params  # a deep copy.
    if 'config' in input_fn_args:
      kwargs['config'] = config

    # Now for TPU training.
    if 'params' in kwargs and _BATCH_SIZE_KEY in kwargs['params']:
      kwargs['params'][_BATCH_SIZE_KEY] //= config.tpu_config.num_shards

    job = _tpu_job(config)
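    # Place each shard's input pipeline on the CPU of the host it belongs to.
    # This assumes 8 TPU cores per host, so shard `index` maps to host task
    # `index // 8` (and TPU ordinal `index % 8` when enqueueing).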
    def placement_function(index):
      if job is None:
        return '/replica:0/task:0/device:CPU:0'
      else:
        return '/job:%s/replica:0/task:%d/device:CPU:0' % (job, index // 8)

    features = []
    labels = []
    for i in range(config.tpu_config.num_shards):
      with ops.device(placement_function(i)):
        result = input_fn(**kwargs)
        # input_fn may return either features or (features, labels)
        if isinstance(result, tuple):
          features.append(result[0])
          labels.append(result[1])
        else:
          features.append(result)
    if not labels or all(l is None for l in labels):
      return _PerShardOutput(features), None
    return _PerShardOutput(features), _PerShardOutput(labels)


def _verify_estimator_spec(estimator_spec):
  """Validates the estimator_spec."""
  err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.'
  if estimator_spec.training_chief_hooks:
    raise ValueError(err_msg.format('training_chief_hooks'))
  if estimator_spec.training_hooks:
    raise ValueError(err_msg.format('training_hooks'))
  return estimator_spec


def _call_model_fn(model_fn, features, labels, mode, config, params):
  """Calls the model_fn with required parameters."""
  model_fn_args = estimator_lib._fn_args(model_fn)  # pylint: disable=protected-access
  kwargs = {}
  if 'mode' in model_fn_args:
    kwargs['mode'] = mode
  if 'params' in model_fn_args:
    kwargs['params'] = params
  if 'config' in model_fn_args:
    kwargs['config'] = config
  return model_fn(features=features, labels=labels, **kwargs)


def _call_model_fn_with_tpu(model_fn, features, labels, mode, config, params):
  """Calls user provided `model_fn` and verifies the estimator_spec."""
  # Makes a deep copy of `config` and `params` in case the user mutates them.
  config = copy.deepcopy(config)
  params = copy.deepcopy(params)
  return _verify_estimator_spec(_call_model_fn(
      model_fn, features, labels, mode, config, params))


def _call_model_fn_without_tpu(
    model_fn, features, labels, mode, config, params):
  # Deepcopy of config and params is not required in this branch.
  return _call_model_fn(model_fn, features, labels, mode, config, params)


# TODO(xiejw): Improve the structure of this input_fn-to-infeed conversion.
# The code currently does not follow the Estimator style. We need to abstract
# many details.
def _create_infeed_enqueue_ops_and_dequeue_fn(run_config, features, labels):
  """Utility to convert input_fn to enqueue and dequeue fns for TPU.

  Mainly, three things need to be done here.
  1. Calls the input_fn many times (`num_shards`) to infeed the data into TPU
  2. Create a dequeue_fn used by the train_step inside TPU execution to
  dequeue the tensors.
  3. Sets up the input thread to infeed.

  Args:
    run_config: run_config
    features: features
    labels: labels

  Returns:
    A tuple of (dequeue_fn, enqueue_fn)
  """
  infeed_names = None
  sharded_inputs = []
  if isinstance(features[0], dict):
    # We need a fixed ordering for enqueueing and dequeueing.
    infeed_names = [name for name in features[0]]

  for shard in range(run_config.tpu_config.num_shards):
    inputs = []
    if infeed_names is None:
      inputs.append(features[shard])
    else:
      for name in infeed_names:
        inputs.append(features[shard][name])
    if labels is not None:
      inputs.append(labels[shard])
    sharded_inputs.append(inputs)

  infeed_queue = tpu_feed.InfeedQueue(
      number_of_tuple_elements=len(sharded_inputs[0]))
  infeed_queue.set_configuration_from_sharded_input_tensors(sharded_inputs)

  def dequeue_fn():
    """dequeue_fn is used by the train_step in TPU to retrieve the tensors."""
    values = infeed_queue.generate_dequeue_op()
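    # The per-shard enqueue order is all feature tensors first, followed by
    # the label (if present), so the label is always the last dequeued value.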

    expected_num_tensors = 0
    if labels is not None:
      expected_num_tensors += 1
    if infeed_names is None:
      expected_num_tensors += 1
    else:
      expected_num_tensors += len(infeed_names)
    assert len(values) == expected_num_tensors

    dequeue_label = None
    if labels is not None:
      dequeue_label = values[-1]
    if infeed_names is None:
      return values[0], dequeue_label
    # Restore the feature dictionary and label.
    dequeued_features = {}
    for i in range(len(infeed_names)):
      dequeued_features[infeed_names[i]] = values[i]
    return dequeued_features, dequeue_label

  def tpu_ordinal_function(index):
    """Return the TPU ordinal associated with a shard.

    Required because the enqueue ops are placed on CPU.

    Args:
      index: the shard index

    Returns:
      The ordinal of the TPU device the shard's infeed should be placed on.
    """
    return index % 8

  def enqueue_fn():
    """enqueue_fn is used to add ops to the graph to send tensors."""
    return infeed_queue.generate_enqueue_ops(
        sharded_inputs, tpu_ordinal_function=tpu_ordinal_function)

  return (dequeue_fn, enqueue_fn)


def wrapped_model_fn(model_fn):
  """Returns a new model_fn, which wraps the TPU support."""

  def _model_fn(features, labels, mode, config, params=None):
    """model_fn."""

    # TODO(jhseu): Move EVAL and PREDICT to TPU.
    if mode != model_fn_lib.ModeKeys.TRAIN:
      return _call_model_fn_without_tpu(
          model_fn, features, labels, mode, config, params)

    # Now for TPU training.
    if params is not None and _BATCH_SIZE_KEY in params:
      params[_BATCH_SIZE_KEY] //= config.tpu_config.num_shards

    assert isinstance(features, _PerShardOutput)
    features = features.as_list()
    if labels is not None:
      assert isinstance(labels, _PerShardOutput)
      labels = labels.as_list()

    dequeue_fn, enqueue_fn = (
        _create_infeed_enqueue_ops_and_dequeue_fn(config, features, labels))

    loss = _train_on_tpu_shards(
        config,
        train_step=_convert_model_fn_to_train_step(
            model_fn, dequeue_fn, mode, config, params))

    # Gets the variables back from TPU nodes. This means the variables updated
    # by TPU will now be *synced* to host memory.
    update_ops = [
        array_ops.check_numerics(v.read_value(),
                                 'Gradient for %s is NaN' % v.name).op
        for v in variables.trainable_variables()
    ]

    hooks = [
        TpuInfeedSessionHook(config, enqueue_fn),
        training.LoggingTensorHook(
            {'loss': array_ops.identity(loss),
             'step': training.get_global_step()},
            every_n_secs=30)
    ]

    return model_fn_lib.EstimatorSpec(
        mode,
        loss=array_ops.identity(loss),
        training_hooks=hooks,
        train_op=control_flow_ops.group(*update_ops))
  return _model_fn


def _convert_model_fn_to_train_step(model_fn, dequeue_fn, mode, run_config,
                                    params):
  """Generates a train step based on the model_fn."""

  def train_step(loss):
    """Training step function for use inside a while loop."""
    del loss  # unused; required in function signature.
    features, labels = dequeue_fn()

    # TODO(xiejw): How do we support hooks and savers in the original
    # model_fn? Realistically, the original model_fn will be executed on TPU
    # chips in a replicated way. The hooks returned by the model_fn cannot be
    # supported at all. If we have to, the graph construction part of the
    # model_fn should be separated from the control part (such as hooks and
    # savers). That way the graph construction could be deferred to the TPU
    # chip, while the control logic can stay on the host.
    estimator_spec = _call_model_fn_with_tpu(
        model_fn, features, labels, mode, run_config, params)
    loss, train_op = estimator_spec.loss, estimator_spec.train_op
    with ops.control_dependencies([train_op]):
      return array_ops.identity(loss)
  return train_step


def _train_on_tpu_shards(run_config, train_step):
  """Executes the `train_step` on all shards."""
  def train_shard():
    return training_loop.repeat(run_config.tpu_config.iterations_per_loop,
                                train_step,
                                [1e7],  # initial_loss
                                name='loop')

  (loss,) = tpu.shard(train_shard,
                      inputs=[],
                      num_shards=run_config.tpu_config.num_shards,
                      outputs_from_all_shards=False)
  return loss