Diffstat (limited to 'tensorflow/contrib/tpu/python/tpu/tpu_estimator.py')
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py  144
1 file changed, 53 insertions(+), 91 deletions(-)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 07877fcc76..060b3f9129 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -66,7 +66,7 @@ _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum'
_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY]
# TODO(b/65703635): Flip the value and remove all dead code.
-_WRAP_INPUT_FN_INTO_WHILE_LOOP = False
+_WRAP_INPUT_FN_INTO_WHILE_LOOP = True
def _create_global_step(graph):
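With the flag flipped, the enqueue computation built from input_fn is wrapped in a device-side while loop instead of being re-issued from the host each step. A minimal standalone sketch of the idea behind the module's _wrap_computation_in_while_loop helper (the helper's real body is not shown in this diff; the names and counter-based loop below are illustrative only):

import tensorflow as tf

def wrap_in_while_loop(device, op_fn, iterations):
  # Re-run the ops produced by op_fn() `iterations` times on `device`.
  def cond(i):
    return i < iterations

  def body(i):
    with tf.control_dependencies(op_fn()):
      return i + 1

  with tf.device(device):
    return tf.while_loop(cond, body, [tf.constant(0)])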
@@ -232,10 +232,8 @@ class _TPUContext(object):
mode == model_fn_lib.ModeKeys.TRAIN
else self._eval_batch_size)
# On TPU
- if self.is_input_sharded_per_core():
- return global_batch_size // self.num_cores
- else:
- return global_batch_size // self.num_hosts
+ return (global_batch_size // self.num_cores
+ if self.is_input_sharded_per_core() else global_batch_size)
@property
def batch_size_for_model_fn(self):
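Put concretely, with per-core sharding each core's pipeline receives global_batch_size // num_cores examples, while the per-host path now hands the full global batch to a single input_fn call (the earlier division by num_hosts is gone). Illustrative numbers:

# Illustrative values only.
global_batch_size = 1024
num_cores = 8

per_core_batch = global_batch_size // num_cores  # 128 when sharded per core
per_host_batch = global_batch_size               # 1024 for the per-host path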
@@ -537,15 +535,13 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
session, self._dequeue_ops)
def before_run(self, run_context):
- iterations = run_context.session.run(self._iterations_per_loop_var)
-
- logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations)
+ logging.info('Enqueue next batch of data to infeed.')
+ iterations = run_context.session.run(self._iterations_per_loop_var)
self._infeed_thd_controller.send_next_batch_signal(iterations)
if self._dequeue_ops is not None:
# TODO(xiejw): Refactor the outfeed dequeue into tf.while_loop.
- logging.info(
- 'Dequeue next (%d) batch(es) of data from outfeed.', iterations)
+ logging.info('Dequeue next batch of data from outfeed.')
self._outfeed_thd_controller.send_next_batch_signal(iterations)
def end(self, session):
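The reordering only affects logging: the enqueue message is emitted before the iterations variable is read, so it no longer reports a batch count. A condensed sketch of the hook pattern used here, assuming a thread-controller object with a send_next_batch_signal(iterations) method (a stand-in for the module's infeed/outfeed thread controllers):

import tensorflow as tf

class InfeedSignalHook(tf.train.SessionRunHook):
  """Signals a background enqueue thread once per outer training loop."""

  def __init__(self, iterations_var, controller):
    self._iterations_var = iterations_var
    self._controller = controller

  def before_run(self, run_context):
    tf.logging.info('Enqueue next batch of data to infeed.')
    iterations = run_context.session.run(self._iterations_var)
    self._controller.send_next_batch_signal(iterations)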
@@ -684,40 +680,6 @@ def generate_per_core_enqueue_ops_fn_for_host(
return enqueue_ops_fn, (lambda: infeed_queue_holder['instance'])
-def generate_per_host_enqueue_ops_fn_for_host(
- ctx, input_fn, inputs_structure_recorder, batch_axis, device):
- """Generates infeed enqueue ops for per-host input_fn on a single host."""
- infeed_queue_holder = {'instance': None}
-
- def enqueue_ops_fn():
- with ops.device(device):
- num_cores_per_host = ctx.num_of_cores_per_host
- inputs = input_fn()
- if isinstance(inputs, tuple):
- features, labels = inputs
- else:
- features, labels = inputs, None
- inputs_structure_recorder.validate_and_record_structure(
- features, labels)
- unsharded_tensor_list = (
- inputs_structure_recorder.flatten_features_and_labels(
- features, labels))
-
- infeed_queue = tpu_feed.InfeedQueue(
- tuple_types=[t.dtype for t in unsharded_tensor_list],
- tuple_shapes=[t.shape for t in unsharded_tensor_list],
- shard_dimensions=batch_axis)
- infeed_queue_holder['instance'] = infeed_queue
- infeed_queue.set_number_of_shards(num_cores_per_host)
-
- per_host_enqueue_ops = (
- infeed_queue.split_inputs_and_generate_enqueue_ops(
- unsharded_tensor_list,
- placement_function=lambda x: device))
- return per_host_enqueue_ops
- return enqueue_ops_fn, (lambda: infeed_queue_holder['instance'])
-
-
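The deleted helper's core pattern (flatten features and labels, build an InfeedQueue sharded along batch_axis, generate enqueue ops on the host device) now lives inline in _invoke_input_fn_and_record_structure further down. A condensed sketch of that pattern using only calls that appear in this file; the import path and argument values are assumptions for illustration:

from tensorflow.contrib.tpu.python.tpu import tpu_feed

def build_per_host_enqueue_ops(unsharded_tensor_list, batch_axis, num_shards,
                               device):
  # Build one InfeedQueue over the flattened tensors, sharded along batch_axis.
  infeed_queue = tpu_feed.InfeedQueue(
      tuple_types=[t.dtype for t in unsharded_tensor_list],
      tuple_shapes=[t.shape for t in unsharded_tensor_list],
      shard_dimensions=batch_axis)
  infeed_queue.set_number_of_shards(num_shards)
  enqueue_ops = infeed_queue.split_inputs_and_generate_enqueue_ops(
      unsharded_tensor_list, placement_function=lambda x: device)
  return infeed_queue, enqueue_ops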
class _InputPipeline(object):
"""`_InputPipeline` handles invoking `input_fn` and piping to infeed queue.
@@ -880,8 +842,6 @@ class _InputPipeline(object):
# structure is recorded.
enqueue_ops = self._invoke_input_fn_and_record_structure()
- self._validate_input_pipeline()
-
def dequeue_fn():
"""dequeue_fn is used by TPU to retrieve the tensors."""
values = self._infeed_queue.generate_dequeue_op()
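On the TPU side, dequeue mirrors enqueue: generate_dequeue_op() yields the flattened tensor list, which is then rebuilt into (features, labels). A hedged sketch of that wiring; the unflatten_features_and_labels call is assumed from the recorder's flatten counterpart shown elsewhere in this diff and may not match the real method name:

def dequeue_fn(infeed_queue, inputs_structure_recorder):
  # Flat tensor list produced on-device from the infeed queue.
  values = infeed_queue.generate_dequeue_op()
  # Assumed helper: rebuild the original (features, labels) nesting.
  return inputs_structure_recorder.unflatten_features_and_labels(values)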
@@ -892,15 +852,15 @@ class _InputPipeline(object):
return (enqueue_ops, dequeue_fn)
def _invoke_input_fn_and_record_structure(self):
- """Deploys the input pipeline and record input structure."""
- enqueue_ops = []
- infeed_queues = []
- num_hosts = self._ctx.num_hosts
- tpu_host_placement_fn = self._ctx.tpu_host_placement_function
if self._sharded_per_core:
# Per-Core input pipeline deployment.
+ tpu_host_placement_fn = self._ctx.tpu_host_placement_function
+ enqueue_ops = []
+ infeed_queues = []
+
# Invoke input pipeline for each core and placed on the corresponding
# host.
+ num_hosts = self._ctx.num_hosts
for host_id in range(num_hosts):
host_device = tpu_host_placement_fn(host_id=host_id)
with ops.device(host_device):
@@ -917,52 +877,48 @@ class _InputPipeline(object):
# Infeed_queue_getter must be called after enqueue_ops_fn is called.
infeed_queues.append(infeed_queue_getter())
+ # infeed_queue is used to generate dequeue ops. The only thing it uses for
+ # dequeue is dtypes and shapes. So, any one can be used. Here, grab the
+ # first one.
+ self._infeed_queue = infeed_queues[0]
+ return enqueue_ops
+
else:
- for host_id in range(num_hosts):
- host_device = tpu_host_placement_fn(host_id=host_id)
+ # TODO(b/67051042): Extend this to multi-host support.
+ host_id = 0
+ host_device = self._ctx.tpu_host_placement_function(host_id=host_id)
+ def enqueue_fn():
with ops.device(host_device):
with ops.name_scope('input_pipeline_task%d' % (host_id)):
- enqueue_ops_fn, infeed_queue_getter = (
- generate_per_host_enqueue_ops_fn_for_host(
- self._ctx, self._input_fn, self._inputs_structure_recorder,
- self._batch_axis, host_device))
-
- if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
- enqueue_ops.append(_wrap_computation_in_while_loop(
- device=host_device, op_fn=enqueue_ops_fn))
+ inputs = self._input_fn()
+ if isinstance(inputs, tuple):
+ features, labels = inputs
else:
- enqueue_ops.append(enqueue_ops_fn())
- infeed_queues.append(infeed_queue_getter())
- # infeed_queue is used to generate dequeue ops. The only thing it uses for
- # dequeue is dtypes and shapes. So, any one can be used. Here, grab the
- # first one.
- self._infeed_queue = infeed_queues[0]
- return enqueue_ops
-
- def _validate_input_pipeline(self):
- # Perform some sanity checks to log user friendly information. We should
- # error out to give users better error message. But, if
- # _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break
- # user code, so, log a warning.
- if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS):
- err_msg = ('Input pipeline contains one or more QueueRunners. '
- 'These are not supported via TPUEstimator. You must convert '
- 'your input pipeline to use `tf.data` instead (see '
- 'https://www.tensorflow.org/programmers_guide/datasets for '
- 'instructions.')
- if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
- raise RuntimeError(err_msg)
- else:
- logging.warn(err_msg)
- elif ops.get_default_graph().get_collection(ops.GraphKeys.SUMMARIES):
- # Queue Runner has summary Ops by default. So here we use elif to do
- # necessary checks for Dataset input pipeline only.
- err_msg = ('Input pipeline contains `tf.summary` operations. '
- 'These are not currently supported.')
+ features, labels = inputs, None
+ self._inputs_structure_recorder.validate_and_record_structure(
+ features, labels)
+ unsharded_tensor_list = (
+ self._inputs_structure_recorder.flatten_features_and_labels(
+ features, labels))
+
+ self._infeed_queue = tpu_feed.InfeedQueue(
+ tuple_types=[t.dtype for t in unsharded_tensor_list],
+ tuple_shapes=[t.shape for t in unsharded_tensor_list],
+ shard_dimensions=self._batch_axis)
+ self._infeed_queue.set_number_of_shards(self._ctx.num_cores)
+
+ def placement_fn(core_id):
+ return self._ctx.tpu_host_placement_function(core_id=core_id)
+ return (
+ self._infeed_queue.split_inputs_and_generate_enqueue_ops(
+ unsharded_tensor_list,
+ placement_function=placement_fn))
+
if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
- raise RuntimeError(err_msg)
+ return _wrap_computation_in_while_loop(device=host_device,
+ op_fn=enqueue_fn)
else:
- logging.warn(err_msg)
+ return enqueue_fn()
class _ModelFnWrapper(object):
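With _validate_input_pipeline removed, the QueueRunner and summary checks it performed no longer run on this path. They can still be reproduced by hand when debugging an input_fn; a sketch using the same graph collections the deleted code inspected:

from tensorflow.python.framework import ops

def check_input_pipeline_graph():
  graph = ops.get_default_graph()
  if graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
    raise RuntimeError(
        'Input pipeline contains one or more QueueRunners; convert it to '
        'use `tf.data` instead.')
  if graph.get_collection(ops.GraphKeys.SUMMARIES):
    raise RuntimeError(
        'Input pipeline contains `tf.summary` operations, which are not '
        'currently supported.')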
@@ -1440,6 +1396,12 @@ class TPUEstimator(estimator_lib.Estimator):
'eval batch size {} must be divisible by number of shards {}'
.format(eval_batch_size, config.tpu_config.num_shards))
+ if (config.tpu_config.num_shards > 8 and
+ config.tpu_config.per_host_input_for_training):
+ # TODO(b/67051042): Support per_host input pipelines when num_shards > 8
+ raise NotImplementedError(
+ 'Per-host input pipelines only available for num_shards <= 8')
+
# Verifies the model_fn signature according to Estimator framework.
estimator_lib._verify_model_fn_args(model_fn, params) # pylint: disable=protected-access
# We cannot store config and params in this constructor as parent
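Taken together with the eval-batch-size check above, the constructor now rejects a per-host input pipeline whenever more than 8 shards are configured, pending b/67051042. An illustrative standalone version of the two guards (not the estimator's actual code path):

def validate_tpu_settings(eval_batch_size, num_shards,
                          per_host_input_for_training):
  if eval_batch_size is not None and eval_batch_size % num_shards != 0:
    raise ValueError(
        'eval batch size {} must be divisible by number of shards {}'.format(
            eval_batch_size, num_shards))
  if num_shards > 8 and per_host_input_for_training:
    raise NotImplementedError(
        'Per-host input pipelines only available for num_shards <= 8')

validate_tpu_settings(eval_batch_size=128, num_shards=16,
                      per_host_input_for_training=True)  # raises NotImplementedError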