Diffstat (limited to 'tensorflow/contrib/tpu/python/tpu/tpu_estimator.py')
 tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 144 ++++++----------
 1 file changed, 53 insertions(+), 91 deletions(-)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 07877fcc76..060b3f9129 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -66,7 +66,7 @@ _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum'
 _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY]
 
 # TODO(b/65703635): Flip the value and remove all dead code.
-_WRAP_INPUT_FN_INTO_WHILE_LOOP = False
+_WRAP_INPUT_FN_INTO_WHILE_LOOP = True
 
 
 def _create_global_step(graph):
@@ -232,10 +232,8 @@ class _TPUContext(object):
         mode == model_fn_lib.ModeKeys.TRAIN else self._eval_batch_size)
 
     # On TPU
-    if self.is_input_sharded_per_core():
-      return global_batch_size // self.num_cores
-    else:
-      return global_batch_size // self.num_hosts
+    return (global_batch_size // self.num_cores
+            if self.is_input_sharded_per_core() else global_batch_size)
 
   @property
   def batch_size_for_model_fn(self):
@@ -537,15 +535,13 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
         session, self._dequeue_ops)
 
   def before_run(self, run_context):
-    iterations = run_context.session.run(self._iterations_per_loop_var)
-
-    logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations)
+    logging.info('Enqueue next batch of data to infeed.')
+    iterations = run_context.session.run(self._iterations_per_loop_var)
     self._infeed_thd_controller.send_next_batch_signal(iterations)
 
     if self._dequeue_ops is not None:
       # TODO(xiejw): Refactor the outfeed dequeue into tf.while_loop.
-      logging.info(
-          'Dequeue next (%d) batch(es) of data from outfeed.', iterations)
+      logging.info('Dequeue next batch of data from outfeed.')
       self._outfeed_thd_controller.send_next_batch_signal(iterations)
 
   def end(self, session):
@@ -684,40 +680,6 @@ def generate_per_core_enqueue_ops_fn_for_host(
   return enqueue_ops_fn, (lambda: infeed_queue_holder['instance'])
 
 
-def generate_per_host_enqueue_ops_fn_for_host(
-    ctx, input_fn, inputs_structure_recorder, batch_axis, device):
-  """Generates infeed enqueue ops for per-host input_fn on a single host."""
-  infeed_queue_holder = {'instance': None}
-
-  def enqueue_ops_fn():
-    with ops.device(device):
-      num_cores_per_host = ctx.num_of_cores_per_host
-      inputs = input_fn()
-      if isinstance(inputs, tuple):
-        features, labels = inputs
-      else:
-        features, labels = inputs, None
-      inputs_structure_recorder.validate_and_record_structure(
-          features, labels)
-      unsharded_tensor_list = (
-          inputs_structure_recorder.flatten_features_and_labels(
-              features, labels))
-
-      infeed_queue = tpu_feed.InfeedQueue(
-          tuple_types=[t.dtype for t in unsharded_tensor_list],
-          tuple_shapes=[t.shape for t in unsharded_tensor_list],
-          shard_dimensions=batch_axis)
-      infeed_queue_holder['instance'] = infeed_queue
-      infeed_queue.set_number_of_shards(num_cores_per_host)
-
-      per_host_enqueue_ops = (
-          infeed_queue.split_inputs_and_generate_enqueue_ops(
-              unsharded_tensor_list,
-              placement_function=lambda x: device))
-      return per_host_enqueue_ops
-  return enqueue_ops_fn, (lambda: infeed_queue_holder['instance'])
-
-
 class _InputPipeline(object):
   """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue.
 
@@ -880,8 +842,6 @@ class _InputPipeline(object):
     # structure is recorded.
     enqueue_ops = self._invoke_input_fn_and_record_structure()
 
-    self._validate_input_pipeline()
-
     def dequeue_fn():
       """dequeue_fn is used by TPU to retrieve the tensors."""
       values = self._infeed_queue.generate_dequeue_op()
@@ -892,15 +852,15 @@ class _InputPipeline(object):
     return (enqueue_ops, dequeue_fn)
 
   def _invoke_input_fn_and_record_structure(self):
-    """Deploys the input pipeline and record input structure."""
-    enqueue_ops = []
-    infeed_queues = []
-    num_hosts = self._ctx.num_hosts
-    tpu_host_placement_fn = self._ctx.tpu_host_placement_function
     if self._sharded_per_core:
       # Per-Core input pipeline deployment.
+      tpu_host_placement_fn = self._ctx.tpu_host_placement_function
+      enqueue_ops = []
+      infeed_queues = []
+
       # Invoke input pipeline for each core and placed on the corresponding
       # host.
+      num_hosts = self._ctx.num_hosts
       for host_id in range(num_hosts):
         host_device = tpu_host_placement_fn(host_id=host_id)
         with ops.device(host_device):
@@ -917,52 +877,48 @@ class _InputPipeline(object):
               enqueue_ops.append(enqueue_ops_fn())
             # Infeed_queue_getter must be called after enqueue_ops_fn is called.
             infeed_queues.append(infeed_queue_getter())
+      # infeed_queue is used to generate dequeue ops. The only thing it uses for
+      # dequeue is dtypes and types. So, any one can be used. Here, grab the
+      # first one.
+      self._infeed_queue = infeed_queues[0]
+      return enqueue_ops
+
     else:
-      for host_id in range(num_hosts):
-        host_device = tpu_host_placement_fn(host_id=host_id)
+      # TODO(b/67051042): Extend this to multi-host support.
+      host_id = 0
+      host_device = self._ctx.tpu_host_placement_function(host_id=host_id)
+      def enqueue_fn():
         with ops.device(host_device):
           with ops.name_scope('input_pipeline_task%d' % (host_id)):
-            enqueue_ops_fn, infeed_queue_getter = (
-                generate_per_host_enqueue_ops_fn_for_host(
-                    self._ctx, self._input_fn, self._inputs_structure_recorder,
-                    self._batch_axis, host_device))
-
-            if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
-              enqueue_ops.append(_wrap_computation_in_while_loop(
-                  device=host_device, op_fn=enqueue_ops_fn))
+            inputs = self._input_fn()
+            if isinstance(inputs, tuple):
+              features, labels = inputs
             else:
-              enqueue_ops.append(enqueue_ops_fn())
-            infeed_queues.append(infeed_queue_getter())
-    # infeed_queue is used to generate dequeue ops. The only thing it uses for
-    # dequeue is dtypes and types. So, any one can be used. Here, grab the
-    # first one.
-    self._infeed_queue = infeed_queues[0]
-    return enqueue_ops
-
-  def _validate_input_pipeline(self):
-    # Perform some sanity checks to log user friendly information. We should
-    # error out to give users better error message. But, if
-    # _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break
-    # user code, so, log a warning.
-    if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS):
-      err_msg = ('Input pipeline contains one or more QueueRunners. '
-                 'These are not supported via TPUEstimator. You must convert '
-                 'your input pipeline to use `tf.data` instead (see '
-                 'https://www.tensorflow.org/programmers_guide/datasets for '
-                 'instructions.')
-      if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
-        raise RuntimeError(err_msg)
-      else:
-        logging.warn(err_msg)
-    elif ops.get_default_graph().get_collection(ops.GraphKeys.SUMMARIES):
-      # Queue Runner has summary Ops by default. So here we use elif to do
-      # necessary checks for Dataset input pipeline only.
-      err_msg = ('Input pipeline contains `tf.summary` operations. '
-                 'These are not currently supported.')
+              features, labels = inputs, None
+            self._inputs_structure_recorder.validate_and_record_structure(
+                features, labels)
+            unsharded_tensor_list = (
+                self._inputs_structure_recorder.flatten_features_and_labels(
+                    features, labels))
+
+            self._infeed_queue = tpu_feed.InfeedQueue(
+                tuple_types=[t.dtype for t in unsharded_tensor_list],
+                tuple_shapes=[t.shape for t in unsharded_tensor_list],
+                shard_dimensions=self._batch_axis)
+            self._infeed_queue.set_number_of_shards(self._ctx.num_cores)
+
+            def placement_fn(core_id):
+              return self._ctx.tpu_host_placement_function(core_id=core_id)
+            return (
+                self._infeed_queue.split_inputs_and_generate_enqueue_ops(
+                    unsharded_tensor_list,
+                    placement_function=placement_fn))
+
       if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
-        raise RuntimeError(err_msg)
+        return _wrap_computation_in_while_loop(device=host_device,
+                                               op_fn=enqueue_fn)
       else:
-        logging.warn(err_msg)
+        return enqueue_fn()
 
 
 class _ModelFnWrapper(object):
@@ -1440,6 +1396,12 @@ class TPUEstimator(estimator_lib.Estimator):
             'eval batch size {} must be divisible by number of shards {}'
             .format(eval_batch_size, config.tpu_config.num_shards))
 
+      if (config.tpu_config.num_shards > 8 and
+          config.tpu_config.per_host_input_for_training):
+        # TODO(b/67051042): Support per_host input pipelines when num_shards > 8
+        raise NotImplementedError(
+            'Per-host input pipelines only available for num_shards <= 8')
+
     # Verifies the model_fn signature according to Estimator framework.
     estimator_lib._verify_model_fn_args(model_fn, params)  # pylint: disable=protected-access
     # We cannot store config and params in this constructor as parent
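A note on the headline change: with `_WRAP_INPUT_FN_INTO_WHILE_LOOP` flipped to `True`, the per-host enqueue ops built by `enqueue_fn` are wrapped in a `tf.while_loop`, so a single `session.run` drives a whole loop of infeed enqueues in-graph instead of one enqueue per run call. The diff does not show the file's `_wrap_computation_in_while_loop` helper itself; the sketch below only illustrates the general pattern, and the `device`, `op_fn`, and `iterations` parameters are illustrative stand-ins, not the actual implementation.

    import tensorflow as tf  # TF 1.x graph-mode API, matching tpu_estimator.py


    def wrap_computation_in_while_loop(device, op_fn, iterations):
      """Sketch: run the ops built by `op_fn` for `iterations` steps in-graph.

      tf.while_loop traces `body` once; the traced enqueue ops then execute
      on every loop iteration when the returned op runs, avoiding a Python
      round-trip per batch.
      """

      def body(i):
        # Tie the counter increment to the enqueue ops so each iteration
        # performs one enqueue before advancing.
        with tf.control_dependencies(op_fn()):
          return i + 1

      with tf.device(device):
        # parallel_iterations=1 keeps the enqueues strictly sequential.
        return tf.while_loop(lambda i: i < iterations, body,
                             [tf.constant(0)], parallel_iterations=1)

This also motivates the constructor check added at the end of the diff: the rewritten per-host path builds its `InfeedQueue` for `host_id = 0` only, so per-host input pipelines are rejected when `num_shards > 8` (i.e. more than one TPU host) until b/67051042 extends the path to multiple hosts.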